├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── LICENSE
├── README.md
├── deepinsight
│   ├── __init__.py
│   ├── analyse.py
│   ├── architecture.py
│   ├── preprocess.py
│   ├── train.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── custom_losses.py
│   │   ├── data_generator.py
│   │   ├── hdf5.py
│   │   ├── opts.py
│   │   ├── stats.py
│   │   ├── tetrode.py
│   │   └── wavelet_transform.py
│   └── visualize.py
├── media
│   ├── colab_walkthrough.gif
│   ├── decoding_error.gif
│   └── model_architecture.png
├── notebooks
│   ├── deepinsight_calcium_example.ipynb
│   ├── example_data
│   │   └── calcium
│   │       └── calcium_rois.jpg
│   └── static
│       ├── calcium_example.ipynb
│       └── ephys_example.ipynb
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── run_test.py
    └── tests.ipynb
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: build
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.7
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.7
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install pytest
27 | pip install -e git+https://github.com/CYHSM/DeepInsight.git#egg=DeepInsight
28 | pip install git+https://github.com/CYHSM/wavelets
29 | #- name: Lint with flake8
30 | # run: |
31 | # # stop the build if there are Python syntax errors or undefined names
32 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
33 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
34 | # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
35 | - name: Test with pytest
36 | run: |
37 | pytest tests/run_test.py
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.p
3 | *.h5
4 | *.hdf5
5 | logs*
6 | .ipynb_checkpoints
7 | *.mp4
8 | /data
9 | .vscode
10 | *.mat
11 | *.html
12 |
13 | # Distribution / packaging
14 | .Python
15 | env/
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | notebooks/private/
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Markus Frey
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://github.com/CYHSM/DeepInsight/blob/master/LICENSE.md)
2 | 
3 | 
4 |
5 | # DeepInsight: A general framework for interpreting wide-band neural activity
6 |
7 | DeepInsight is a toolbox for the analysis and interpretation of wide-band neural activity and can be applied to unsorted neural data. This means the traditional spike-sorting step can be omitted and the raw data can be used directly as input, providing a more objective way of measuring decoding performance.
8 | ![Model Architecture](media/model_architecture.png)
9 |
10 | ## Google Colaboratory
11 |
12 | We created a Colab notebook to showcase how to analyse your own two-photon calcium imaging data. We provide both the raw and the preprocessed dataset as downloads in case you just want to train the model. You can replace the code which loads the traces with your own data handling and train the model directly in the browser to decode your behaviour or stimuli.
13 |
14 | [](https://colab.research.google.com/drive/11RXK7JIgVM8Zy9M7xEtt1k62i3JXbZLU)
15 | ![Colab Walkthrough](media/colab_walkthrough.gif)
16 |
17 | ## Example Usage
18 | ```python
19 | import deepinsight
20 |
21 | # Load your electrophysiological or calcium-imaging data
22 | (raw_data,
23 | raw_timestamps,
24 | output,
25 | output_timestamps,
26 | info) = deepinsight.util.tetrode.read_tetrode_data(fp_raw_file)
27 |
28 | # Transform raw data to frequency domain
29 | deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=info['sampling_rate'],
30 | channels=info['channels'])
31 |
32 | # Prepare outputs
33 | deepinsight.util.tetrode.preprocess_output(fp_deepinsight, raw_timestamps, output, output_timestamps,
34 | sampling_rate=info['sampling_rate'])
35 |
36 | # Train the model
37 | deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights)
38 |
39 | # Get loss and shuffled loss for influence plot
40 | losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=10)
41 | shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1, stepsize=10)
42 |
43 | # Plot influence across behaviours
44 | deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2)
45 | ```
46 |
47 | See also the [jupyter notebook](notebooks/static/ephys_example.ipynb) for a full example of decoding behaviours from tetrode CA1 recordings. Note that the static notebook does not include the interactive plots shown in the above Colab notebook. The expected run time for a high sampling rate dataset (e.g. tetrode recordings) is highly dependent on the number of channels and the duration of the experiment. Preprocessing can take up to one day for a 128-channel, 1-hour experiment, while training the model takes between 6 and 12 hours. For calcium recordings, preprocessing time shrinks to minutes.
48 |
49 | The following video shows the performance of the model trained on position (left), head direction (top right) and speed (bottom right):
50 | ![Decoding Performance](media/decoding_error.gif)
51 |
52 | ## Installation
53 | Install DeepInsight with the following command (Installation time ~ 2 minutes, depending on internet speed):
54 | ```
55 | pip install git+https://github.com/CYHSM/DeepInsight.git
56 | ```
57 |
58 | If you prefer to use DeepInsight from within your browser, we provide Colab notebooks that guide you through using DeepInsight with your own data.
59 |
60 | - How to use DeepInsight with two-photon calcium imaging data [](https://colab.research.google.com/drive/11RXK7JIgVM8Zy9M7xEtt1k62i3JXbZLU)
61 |
62 | - How to use DeepInsight with electrophysiology data [](https://colab.research.google.com/drive/1h3RYr3r0Zs2k6I53bTiYRq_6VQo38iMP)
63 |
64 | ## System Requirements
65 |
66 | ### Hardware requirements
67 | For preprocessing raw data with a high sampling rate, we recommend using at least 4 parallel cores. For calcium recordings, one core is enough. For training the model, we recommend a GPU with at least 6 GB of memory.
68 |
69 | ### Software requirements
70 | The following Python dependencies are installed automatically when installing DeepInsight (specified in requirements.txt):
71 | ```
72 | tensorflow-gpu (2.1.0)
73 | numpy (1.18.1)
74 | pandas (1.0.1)
75 | joblib (0.14.1)
76 | seaborn (0.10.0)
77 | matplotlib (3.1.3)
78 | h5py (2.10.0)
79 | scipy (1.4.1)
80 | ipython (7.12.0)
81 | ```
82 | Versions in parentheses indicate the ones used for testing the framework. It is extensively tested on Ubuntu 16.04 but should run on any OS (Windows, Mac, Linux) supporting Python >3.6 and pip. We recommend installing the framework and its dependencies in a virtual environment (e.g. conda).
--------------------------------------------------------------------------------
/deepinsight/__init__.py:
--------------------------------------------------------------------------------
1 | from . import util
2 | from . import preprocess
3 | from . import architecture
4 | from . import train
5 | from . import analyse
6 | from . import visualize
7 |
--------------------------------------------------------------------------------
/deepinsight/analyse.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from . import util
8 | import h5py
9 | import numpy as np
10 | import pandas as pd
11 | from scipy.stats import spearmanr
12 | from tensorflow.compat.v1.keras.backend import clear_session, placeholder, get_session
13 | import os
14 |
15 | import tensorflow as tf
16 | tf.compat.v1.disable_eager_execution()
17 |
18 |
19 | def get_model_loss(fp_hdf_out, stepsize=1, shuffles=None, axis=0, verbose=1, fp_test=None, timestep=-1):
20 | """
21 | Loops across cross validated models and calculates loss and predictions for full experiment length
22 |
23 | Parameters
24 | ----------
25 | fp_hdf_out : str
26 | File path to HDF5 file
27 | stepsize : int, optional
28 | Determines how many samples will be evaluated. 1 -> N samples evaluated,
29 | 2 -> N/2 samples evaluated, etc..., by default 1
30 | shuffles : dict, optional
31 | If wavelets should be shuffled, important for calculating influence scores, by default None
32 |
33 | Returns
34 | -------
35 | losses : (N,1) array_like
36 | Loss between predicted and ground truth observation
37 | predictions : dict
38 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...)
39 | indices : (N,1) array_like
40 | Indices which were evaluated, important when taking stepsize unequal to 1
41 | """
42 | dirname = os.path.dirname(fp_hdf_out)
43 | if fp_test is None:
44 | filename = os.path.basename(fp_hdf_out)[0:-3]
45 | else:
46 | filename = os.path.basename(fp_test)[0:-3]
47 | cv_results = []
48 | (_, _, _, opts) = util.hdf5.load_model_with_opts(
49 | dirname + '/models/' + filename + '_model_{}.h5'.format(0))
50 | loss_names = opts['loss_names']
51 | time_shift = opts['model_timesteps']
52 | if verbose > 0:
53 | progress_bar = tf.keras.utils.Progbar(
54 | opts['num_cvs'], width=30, verbose=1, interval=0.05, unit_name='run')
55 | for k in range(0, opts['num_cvs']):
56 | clear_session()
57 | # Find folders
58 | model_path = dirname + '/models/' + filename + '_model_{}.h5'.format(k)
59 | # Load model and generators
60 | (model, training_generator, testing_generator, opts) = util.hdf5.load_model_with_opts(model_path)
61 | if fp_test is not None:
62 | opts['fp_hdf_out'] = fp_hdf_out
63 | opts['handle_nan'] = False
64 | hdf5_file = h5py.File(fp_hdf_out, mode='r')
65 | wavelets = hdf5_file['inputs/wavelets'][()]
66 | hdf5_file.close()
67 |             opts['training_indices'] = np.arange(0, wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size']))
68 |             opts['testing_indices'] = np.arange(0, wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size']))
69 | (training_generator, testing_generator) = util.data_generator.create_train_and_test_generators(opts)
70 | #testing_generator.cv_indices = np.arange(0, testing_generator.wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size']))
71 | # -----------------------------------------------------------------------------------------------
72 | if shuffles is not None:
73 | testing_generator = shuffle_wavelets(
74 | training_generator, testing_generator, shuffles)
75 | losses, predictions, indices = calculate_losses_from_generator(
76 | testing_generator, model, verbose=verbose-1, stepsize=stepsize)
77 | # -----------------------------------------------------------------------------------------------
78 | cv_results.append((losses, predictions, indices))
79 | if verbose > 0:
80 | progress_bar.add(1)
81 | cv_results = np.array(cv_results, dtype='object')
82 | # Reshape cv_results
83 | losses = np.concatenate(cv_results[:, 0], axis=0)
84 | predictions = {k: [] for k in loss_names}
85 | for out in cv_results[:, 1]:
86 | for p, name in zip(out, loss_names):
87 | predictions[name].append(p)
88 | for key, item in predictions.items():
89 | tmp_output = np.concatenate(predictions[key], axis=0)[:, timestep, :]
90 | predictions[key] = tmp_output
91 | indices = np.concatenate(cv_results[:, 2], axis=0)
92 | # We only take the last timestep for decoding, so decoder does not see any part of the future
93 | # We also need to shift the indices as we decode samples within the time window
94 | time_shifts = np.arange(0, testing_generator.model_timesteps + 1, testing_generator.average_output)[1::] - 1
95 | indices = indices + time_shifts[timestep]
96 |
97 | # Also save to HDF5
98 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
99 | for key, item in predictions.items():
100 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/predictions/{}_axis{}_stepsize{}".format(key, axis, stepsize),
101 | dataset_shape=item.shape, dataset_type=np.float32, dataset_value=item)
102 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/losses_axis{}_stepsize{}".format(axis, stepsize),
103 | dataset_shape=losses.shape, dataset_type=np.float32, dataset_value=losses)
104 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/indices_axis{}_stepsize{}".format(axis, stepsize),
105 | dataset_shape=indices.shape, dataset_type=np.int64, dataset_value=indices)
106 |
107 | # Add real to output of this function
108 | output_real = dict()
109 | for idx, (key, y_pred) in enumerate(predictions.items()):
110 | output_real[key] = np.array(hdf5_file['outputs/{}'.format(key)])[indices, ...]
111 |
112 | hdf5_file.close()
113 |
114 | # Report model performance
115 | if verbose > 0:
116 | df_stats = calculate_model_stats(losses, predictions, indices, output_real)
117 | print(df_stats)
118 |
119 | return losses, predictions, indices, output_real
120 |
121 |
122 | def get_shuffled_model_loss(fp_hdf_out, stepsize=1, axis=0, verbose=1):
123 | """
124 | Shuffles the wavelets and recalculates error
125 |
126 | Parameters
127 | ----------
128 | fp_hdf_out : str
129 | File path to HDF5 file
130 | stepsize : int, optional
131 | Determines how many samples will be evaluated. 1 -> N samples evaluated,
132 | 2 -> N/2 samples evaluated, etc..., by default 1
133 | axis : int, optional
134 | Which axis to shuffle
135 |
136 | Returns
137 | -------
138 | shuffled_losses : (N,1) array_like
139 | Loss between predicted and ground truth observation for shuffled wavelets on specified axis
140 | """
141 | if axis == 0:
142 | raise ValueError(
143 | 'Shuffling across time dimension (axis=0) not supported yet.')
144 | hdf5_file = h5py.File(fp_hdf_out, mode='r')
145 | tmp_wavelets_shape = hdf5_file['inputs/wavelets'].shape
146 | hdf5_file.close()
147 | shuffled_losses = []
148 | if verbose > 0:
149 | progress_bar = tf.keras.utils.Progbar(
150 | tmp_wavelets_shape[axis], width=30, verbose=1, interval=0.05, unit_name='run')
151 | for s in range(0, tmp_wavelets_shape[axis]):
152 | if axis == 1:
153 | losses = get_model_loss(fp_hdf_out, stepsize=stepsize, shuffles={'f': s}, axis=axis, verbose=0)[0]
154 | elif axis == 2:
155 | losses = get_model_loss(fp_hdf_out, stepsize=stepsize, shuffles={'c': s}, axis=axis, verbose=0)[0]
156 | shuffled_losses.append(losses)
157 | if verbose > 0:
158 | progress_bar.add(1)
159 | shuffled_losses = np.array(shuffled_losses)
160 | # Also save to HDF5
161 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
162 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/influence/shuffled_losses_axis{}_stepsize{}".format(axis, stepsize),
163 | dataset_shape=shuffled_losses.shape, dataset_type=np.float32, dataset_value=shuffled_losses)
164 | hdf5_file.close()
165 |
166 | return shuffled_losses
167 |
168 |
169 | def calculate_losses_from_generator(tg, model, num_steps=None, stepsize=1, verbose=0):
170 | """
171 | Keras evaluate_generator only returns a scalar loss (mean) while predict_generator only returns the predictions but not the real labels
172 | TODO Make it batch size independent
173 |
174 | Parameters
175 | ----------
176 | tg : object
177 | Data generator
178 | model : object
179 | Keras model
180 | num_steps : int, optional
181 | How many steps should be evaluated, by default None (runs through full experiment)
182 | stepsize : int, optional
183 | Determines how many samples will be evaluated. 1 -> N samples evaluated,
184 | 2 -> N/2 samples evaluated, etc..., by default 1
185 | verbose : int, optional
186 | Verbosity level
187 |
188 | Returns
189 | -------
190 | losses : (N,1) array_like
191 | Loss between predicted and ground truth observation
192 | predictions : dict
193 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...)
194 | indices : (N,1) array_like
195 | Indices which were evaluated, important when taking stepsize unequal to 1
196 | """
197 | # X.) Parse inputs
198 | if num_steps is None:
199 | num_steps = len(tg)
200 |
201 | # 1.) Make a copy and adjust attributes
202 | tmp_dict = tg.__dict__.copy()
203 | if tg.batch_size != 1:
204 | tg.batch_size = 1
205 | tg.random_batches = False
206 | tg.shuffle = False
207 | tg.sample_size = tg.model_timesteps * tg.batch_size
208 |
209 | # 2.) Get output tensors
210 | sess = get_session()
211 | (_, test_out) = tg.__getitem__(0)
212 | real_tensor, calc_tensors = placeholder(), []
213 | for output_index in range(0, len(test_out)):
214 | prediction_tensor = model.outputs[output_index]
215 | loss_tensor = model.loss_functions[output_index].fn(
216 | real_tensor, prediction_tensor)
217 | calc_tensors.append((prediction_tensor, loss_tensor))
218 |
219 | # 3.) Predict
220 | losses, predictions, indices = [], [], []
221 | for i in range(0, num_steps, stepsize):
222 | (in_tg, out_tg) = tg.__getitem__(i)
223 | indices.append(tg.cv_indices[i])
224 | loss, prediction = [], []
225 | for o in range(0, len(out_tg)):
226 | evaluated = sess.run(calc_tensors[o], feed_dict={
227 | model.input: in_tg, real_tensor: out_tg[o]})
228 | prediction.append(evaluated[0][0, ...])
229 | loss.append(evaluated[1][0, ...]) # Get rid of batch dimensions
230 | predictions.append(prediction)
231 | losses.append(loss)
232 | if verbose > 0 and not i % 50:
233 | print('{} / {}'.format(i, num_steps), end='\r')
234 | if verbose > 0:
235 |         print('Evaluated {} steps'.format(num_steps // stepsize))
236 | losses, predictions, indices = np.array(
237 | losses), swap_listaxes(predictions), np.array(indices)
238 | tg.__dict__.update(tmp_dict)
239 |
240 | return losses, predictions, indices
241 |
242 |
243 | def shuffle_wavelets(training_generator, testing_generator, shuffles):
244 | """
245 | Shuffle procedure for model interpretation
246 |
247 | Parameters
248 | ----------
249 | training_generator : object
250 | Data generator for training data
251 | testing_generator : object
252 | Data generator for testing data
253 | shuffles : dict
254 | Indicates which axis to shuffle and which index in selected dimension, e.g. {'f' : 5} shuffles frequency axis 5
255 |
256 | Returns
257 | -------
258 | testing_generator : object
259 | Data generator for testing data with shuffled wavelets
260 | """
261 | rolled_wavelets = training_generator.wavelets.copy()
262 | for key, item in shuffles.items():
263 | if key == 'f':
264 | np.random.shuffle(rolled_wavelets[:, item, :]) # In place
265 | elif key == 'c':
266 | np.random.shuffle(rolled_wavelets[:, :, item]) # In place
267 | elif key == 't':
268 | np.random.shuffle(rolled_wavelets[item, :, :]) # In place
269 | testing_generator.wavelets = rolled_wavelets
270 | return testing_generator
271 |
272 |
273 | def swap_listaxes(list_in):
274 | list_out = []
275 | for o in range(0, len(list_in[0])):
276 | list_out.append(np.array([out[o] for out in list_in]))
277 | return list_out
278 |
279 |
280 | def calculate_model_stats(losses, predictions, indices, real, additional_metrics=[spearmanr]):
281 | """
282 | Calculates statistics on model predictions
283 |
284 | Parameters
285 | ----------
286 |     real : dict
287 |         Dictionary with ground truth observations for each behaviour, same structure as predictions
288 | losses : (N,1) array_like
289 | Loss between predicted and ground truth observation
290 | predictions : dict
291 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...)
292 | indices : (N,1) array_like
293 | Indices which were evaluated, important when taking stepsize unequal to 1
294 | additional_metrics : list, optional
295 |         Additional metrics besides Pearson correlation and model loss, should take arguments (y_true, y_pred) and return either a scalar or a tuple whose first element is used as the metric
296 |
297 | Returns
298 | -------
299 | df_scores
300 | Dataframe of evaluated scores
301 | """
302 | output_scores = []
303 | for idx, ((key, y_pred), (key2, y_true)) in enumerate(zip(predictions.items(), real.items())):
304 | pearson_mean, additional_mean = 0, np.zeros((len(additional_metrics)))
305 | for p in range(y_pred.shape[1]):
306 | pearson_mean += np.corrcoef(y_true[:, p], y_pred[:, p])[0, 1]
307 | for add_idx, am in enumerate(additional_metrics):
308 | am_eval = am(y_true[:, p], y_pred[:, p])
309 | if len(am_eval) > 1:
310 | am_eval = am_eval[0]
311 | additional_mean[add_idx] += am_eval
312 | additional_mean /= y_pred.shape[1]
313 | pearson_mean /= y_pred.shape[1]
314 | loss_mean = np.mean(losses[:, idx])
315 | output_scores.append((pearson_mean, loss_mean, *additional_mean))
316 | additional_columns = [f.__name__.title() for f in additional_metrics]
317 | df_scores = pd.DataFrame(output_scores, index=predictions.keys(), columns=['Pearson', 'Model Loss', *additional_columns])
318 |
319 | return df_scores
320 |
--------------------------------------------------------------------------------
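
As a side note on `shuffle_wavelets` above: the influence scores computed in `get_shuffled_model_loss` come from destroying the temporal structure of a single frequency band (or channel) and measuring how much the decoding loss increases. A minimal numpy sketch of just the shuffling step, assuming the (N, P, M) wavelet layout described in `preprocess.py` (time x frequencies x channels; all shapes here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
wavelets = rng.random((1000, 26, 4))  # hypothetical N=1000 timepoints, P=26 frequencies, M=4 channels

shuffled = wavelets.copy()
np.random.shuffle(shuffled[:, 5, :])  # shuffle frequency band 5 across time, in place (cf. shuffles={'f': 5})

# Only band 5 differs from the original; all other bands keep their temporal structure
changed = np.any(wavelets != shuffled, axis=(0, 2))
assert changed[5] and not changed[4]
```
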
/deepinsight/architecture.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from tensorflow.keras.layers import Conv2D, GaussianNoise, TimeDistributed, Input, Dense, Lambda, Flatten, Dropout
8 | from tensorflow.keras.models import Model
9 | import tensorflow.keras.backend as K
10 |
11 |
12 | def the_decoder(tg, show_summary=True):
13 | """
14 | Model architecture used for decoding from wavelet transformed neural signals
15 |
16 | Parameters
17 | ----------
18 | tg : object
19 | Data generator, holding all important options for creating and training the model
20 | show_summary : bool, optional
21 | Whether to show a summary of the model after creation, by default True
22 |
23 | Returns
24 | -------
25 | model : object
26 | Keras model
27 | """
28 | model_input = Input(shape=tg.input_shape)
29 |
30 | x = GaussianNoise(tg.gaussian_noise)(model_input)
31 | # timestep reductions
32 | for nct in range(0, tg.num_convs_tsr):
33 | x = TimeDistributed(Conv2D(filters=tg.filter_size, kernel_size=(tg.kernel_size, tg.kernel_size), strides=(
34 | 2, 1), padding=tg.conv_padding, activation=tg.act_conv, name='conv_tsr{}'.format(nct)))(x)
35 | x = TimeDistributed(Conv2D(filters=tg.filter_size, kernel_size=(tg.kernel_size, tg.kernel_size), strides=(
36 | 1, 2), padding=tg.conv_padding, activation=tg.act_conv, name='conv_fr{}'.format(nct)))(x)
37 |
38 | # batch x 128 x 60 x 11
39 | x = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1, 4)))(x)
40 |
41 | layer_counter = 0
42 | while (K.int_shape(x)[3] > tg.channel_lower_limit):
43 | x = TimeDistributed(Conv2D(filters=tg.filter_size * 2, kernel_size=(1, 2), strides=(1, 2),
44 | padding=tg.conv_padding, activation=tg.act_conv, name='conv_after_tsr{}'.format(layer_counter)))(x)
45 | layer_counter += 1
46 |
47 | # Flatten and fc
48 | x_flat = TimeDistributed(Flatten())(x)
49 |
50 | outputs = []
51 | for (key, item), output in zip(tg.loss_functions.items(), tg.outputs):
52 | x = x_flat
53 | for d in range(0, tg.num_dense):
54 | x = Dense(tg.num_units_dense, activation=tg.act_fc, name='dense{}_combine{}'.format(d, key))(x)
55 | x = Dropout(tg.dropout_ratio)(x)
56 | out = Dense(output.shape[1], name='{}'.format(key), activation=tg.last_layer_activation_function)(x)
57 | outputs.append(out)
58 |
59 | model = Model(inputs=model_input, outputs=outputs)
60 |
61 | if show_summary:
62 | print(model.summary(line_length=100))
63 |
64 | return model
65 |
--------------------------------------------------------------------------------
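
To make the timestep reduction in `the_decoder` concrete: each `conv_tsr` layer strides by 2 along the time axis, so `num_convs_tsr=4` reduces the 64 input timesteps to 4 output timesteps (matching `opts['model_timesteps']` and `opts['num_convs_tsr']` in `util/opts.py`). A minimal sketch of only this reduction, with hypothetical channel and frequency counts (the real model also strides over frequencies and across channels):

```python
from tensorflow.keras.layers import Conv2D, Input, TimeDistributed
from tensorflow.keras.models import Model

# (channels, timesteps, frequencies, 1); 4 channels and 26 frequencies are assumptions
inp = Input(shape=(4, 64, 26, 1))
x = inp
for _ in range(4):
    x = TimeDistributed(Conv2D(filters=8, kernel_size=3, strides=(2, 1),
                               padding='same', activation='elu'))(x)
print(Model(inp, x).output_shape)  # (None, 4, 4, 26, 8): 64 -> 32 -> 16 -> 8 -> 4 timesteps
```
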
/deepinsight/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import time
8 | from joblib import Parallel, delayed
9 | import numpy as np
10 | import h5py
11 | import tensorflow as tf # Progress bar only
12 | import deepinsight.util.wavelet_transform as wt
13 | from deepinsight.util import hdf5
14 |
15 |
16 | def preprocess_input(fp_hdf_out, raw_data, average_window=1000, channels=None, window_size=100000,
17 | gap_size=50000, sampling_rate=30000, scaling_factor=0.5, num_cores=4, **args):
18 | """
19 |     Transforms raw neural data to frequency space via wavelet transform, currently implemented with aaren-wavelets (https://github.com/aaren/wavelets)
20 | Saves wavelet transformed data to HDF5 file (N, P, M) - (Number of timepoints, Number of frequencies, Number of channels)
21 |
22 | Parameters
23 | ----------
24 | fp_hdf_out : str
25 | File path to HDF5 file
26 | raw_data : (N, M) file or array_like
27 | Variable storing the raw_data (N data points, M channels), should allow indexing
28 | average_window : int, optional
29 | Average window to downsample wavelet transformed input, by default 1000
30 | channels : array_like, optional
31 | Which channels from raw_data to use, by default None
32 | window_size : int, optional
33 | Window size for calculating wavelet transformation, by default 100000
34 | gap_size : int, optional
35 | Gap size for calculating wavelet transformation, by default 50000
36 | sampling_rate : int, optional
37 | Sampling rate of raw_data, by default 30000
38 | scaling_factor : float, optional
39 | Determines amount of log-spaced frequencies P in output, by default 0.5
40 | num_cores : int, optional
41 |         Number of parallel cores to use to calculate wavelet transformation, by default 4
42 | """
43 | # Get number of chunks
44 | if channels is None:
45 | channels = np.arange(0, raw_data.shape[1])
46 | num_points = raw_data.shape[0]
47 | if window_size > num_points:
48 | num_chunks = len(channels)
49 | output_size = raw_data.shape[0]
50 | mean_signal = np.mean(raw_data, axis=1)
51 | average_window = 1
52 | full_transform = True
53 | else:
54 | num_chunks = (num_points // gap_size) - 1
55 | output_size = ((num_chunks + 1) * gap_size) // average_window
56 | full_transform = False
57 |
58 | # Get estimate for number of frequencies
59 | (_, wavelet_frequencies) = wt.wavelet_transform(np.ones(window_size), sampling_rate, average_window, scaling_factor, **args)
60 | num_fourier_frequencies = len(wavelet_frequencies)
61 | # Prepare output file
62 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
63 | if "inputs/wavelets" not in hdf5_file:
64 | hdf5_file.create_dataset("inputs/wavelets", [output_size, num_fourier_frequencies, len(channels)], np.float32)
65 | hdf5_file.create_dataset("inputs/fourier_frequencies", [num_fourier_frequencies], np.float16)
66 | # Makes saving 5 times faster as last index saving is fancy indexing and therefore slow
67 | hdf5_file.create_dataset("inputs/tmp_wavelets", [len(channels), output_size, num_fourier_frequencies], np.float32)
68 |
69 | # Prepare par pool
70 | par = Parallel(n_jobs=num_cores, verbose=0)
71 |
72 | # Start parallel wavelet transformation
73 | print('Starting wavelet transformation (n={}, chunks={}, frequencies={})'.format(
74 | num_points, num_chunks, num_fourier_frequencies))
75 | progress_bar = tf.keras.utils.Progbar(num_chunks, width=30, verbose=1, interval=0.05, unit_name='chunk')
76 | for c in range(0, num_chunks):
77 | if full_transform:
78 | raw_chunk = raw_data[:, c] - mean_signal
79 | else:
80 | start = gap_size * c
81 | end = start + window_size
82 | raw_chunk = raw_data[start: end, channels]
83 | # Process raw chunk
84 | raw_chunk = preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False)
85 |
86 | # Calculate wavelet transform
87 | if full_transform:
88 | (wavelet_power, wavelet_frequencies) = wt.wavelet_transform(raw_chunk,
89 | sampling_rate=sampling_rate, scaling_factor=scaling_factor, average_window=average_window, **args)
90 | else:
91 | wavelet_transformed = np.zeros((raw_chunk.shape[0] // average_window, num_fourier_frequencies, raw_chunk.shape[1]))
92 | for ind, (wavelet_power, wavelet_frequencies) in enumerate(par(delayed(wt.wavelet_transform)(raw_chunk[:, i], sampling_rate, average_window, scaling_factor, **args) for i in range(0, raw_chunk.shape[1]))):
93 | wavelet_transformed[:, :, ind] = wavelet_power
94 |
95 | # Save in output file
96 | if full_transform:
97 | hdf5_file["inputs/tmp_wavelets"][c, :, :] = wavelet_power
98 | else:
99 | wavelet_index_end = end // average_window
100 | wavelet_index_start = start // average_window
101 | index_gap = gap_size // 2 // average_window
102 | if c == 0:
103 | this_index_start = 0
104 | this_index_end = wavelet_index_end - index_gap
105 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[0: -index_gap, :, :]
106 | elif c == num_chunks - 1: # Make sure the last one fits fully
107 | this_index_start = wavelet_index_start + index_gap
108 | this_index_end = wavelet_index_end
109 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[index_gap::, :, :]
110 | else:
111 | this_index_start = wavelet_index_start + index_gap
112 | this_index_end = wavelet_index_end - index_gap
113 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[index_gap: -index_gap, :, :]
114 | hdf5_file.flush()
115 | progress_bar.add(1)
116 |
117 | # 7.) Put frequencies in and close file
118 | if full_transform:
119 | wavelet_power = np.transpose(hdf5_file["inputs/tmp_wavelets"], axes=(1, 2, 0))
120 | del hdf5_file["inputs/tmp_wavelets"]
121 | hdf5_file["inputs/wavelets"][:] = wavelet_power
122 | hdf5_file["inputs/fourier_frequencies"][:] = wavelet_frequencies
123 | hdf5_file.flush()
124 | hdf5_file.close()
125 |
126 |
127 | def preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False):
128 | """
129 | Preprocesses a chunk of data.
130 |
131 | Parameters
132 | ----------
133 | raw_chunk : array_like
134 | Chunk of raw_data to preprocess
135 | subtract_mean : bool, optional
136 | Subtract mean over all other channels, by default True
137 | convert_to_milivolt : bool, optional
138 |         Convert chunk to millivolt, by default False
139 |
140 | Returns
141 | -------
142 | raw_chunk : array_like
143 | preprocessed_chunk
144 | """
145 | # Subtract mean across all channels
146 | if subtract_mean:
147 | raw_chunk = raw_chunk.transpose() - np.mean(raw_chunk.transpose(), axis=0)
148 | raw_chunk = raw_chunk.transpose()
149 |     # Convert to millivolt
150 | if convert_to_milivolt:
151 | raw_chunk = raw_chunk * (0.195 / 1000)
152 | return raw_chunk
153 |
154 |
155 | def preprocess_output(fp_hdf_out, raw_timestamps, output, output_timestamps, average_window=1000, dataset_name='aligned', dataset_type=np.float16):
156 | """
157 |     Base function for preprocessing outputs (handles the M-D case as of March 2020).
158 | For more complex cases use specialized functions (see for example preprocess_output in util.tetrode module)
159 |
160 | Parameters
161 | ----------
162 | fp_hdf_out : str
163 | File path to HDF5 file
164 | raw_timestamps : (N,1) array_like
165 |         Timestamps for each sample in the continuous (raw) data
166 |     output : array_like
167 |         M-dimensional output which will be aligned with the continuous data
168 | output_timestamps : (N,1) array_like
169 | Timestamps for output
170 | average_window : int, optional
171 | Downsampling factor for raw data and output, by default 1000
172 | dataset_name : str, optional
173 | Field name for output stored in HDF5 file
174 | """
175 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
176 |
177 | # Get size of wavelets
178 | input_length = hdf5_file['inputs/wavelets'].shape[0]
179 |
180 |     # Align output with downsampled raw timestamps
181 | raw_timestamps = raw_timestamps[()] # Slightly faster than np.array
182 | if output.ndim == 1:
183 | output = output[..., np.newaxis]
184 |
185 | output_aligned = np.array([np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
186 | average_window)], output_timestamps, output[:, i]) for i in range(output.shape[1])]).transpose()
187 |
188 | # Create and save datasets in HDF5 File
189 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/{}".format(dataset_name),
190 | dataset_shape=[input_length, output_aligned.shape[1]], dataset_type=dataset_type, dataset_value=output_aligned[0: input_length, ...])
191 | hdf5_file.flush()
192 | hdf5_file.close()
193 | print('Successfully written Dataset="{}" to {}'.format(dataset_name, fp_hdf_out))
194 |
--------------------------------------------------------------------------------
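
The chunking logic in `preprocess_input` is easiest to see with numbers. Each window overlaps its neighbours by `gap_size`, and `index_gap` trims half a gap from each side so that only the well-contextualized centre of each wavelet transform is written to the HDF5 file. A worked example with the default parameters (the 300000-sample recording length is hypothetical):

```python
num_points, window_size, gap_size, average_window = 300000, 100000, 50000, 1000

num_chunks = (num_points // gap_size) - 1                      # 5 chunks
output_size = ((num_chunks + 1) * gap_size) // average_window  # 300 output rows
index_gap = gap_size // 2 // average_window                    # 25 rows trimmed per side

for c in range(num_chunks):
    start, end = gap_size * c, gap_size * c + window_size
    lo = start // average_window + (0 if c == 0 else index_gap)
    hi = end // average_window - (0 if c == num_chunks - 1 else index_gap)
    print('chunk {}: raw [{}, {}) -> wavelet rows [{}, {})'.format(c, start, end, lo, hi))
# The ranges tile [0, 300) exactly: [0, 75), [75, 125), [125, 175), [175, 225), [225, 300)
```
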
/deepinsight/train.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import os
8 | import numpy as np
9 | import h5py
10 |
11 | from tensorflow.keras import optimizers
12 | from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau
13 |
14 | import tensorflow.keras.backend as K
15 |
16 | from . import architecture
17 | from . import util
18 |
19 |
20 | def train_model_on_generator(model, training_generator, testing_generator, loss_functions, loss_weights, steps_per_epoch=300, validation_steps=300, loss_metrics=[],
21 | epochs=10, tensorboard_logfolder='./', model_name='', verbose=1, reduce_lr=False, log_output=False, save_model_only=False, compile_only=False):
22 | """
23 | Function for training a given model, with data provided by training and testing generators
24 |
25 | Parameters
26 | ----------
27 | model : object
28 | Keras model
29 | training_generator : object
30 | Data generator for training data
31 | testing_generator : object
32 | Data generator for testing data
33 | loss_functions : dict
34 | Selected loss function for each behaviour
35 | loss_weights : dict
36 | Selected weights for each loss function
37 | steps_per_epoch : int, optional
38 | Number of steps for training the model, by default 300
39 | validation_steps : int, optional
40 | Number of steps for validating the model, by default 300
41 | epochs : int, optional
42 | Number of epochs to train model, by default 10
43 | tensorboard_logfolder : str, optional
44 | Where to store tensorboard logfiles, by default './'
45 | model_name : str, optional
46 | Name of selected model, used to return best model, by default ''
47 | verbose : int, optional
48 | Verbosity level, by default 1
49 | reduce_lr : bool, optional
50 | If True reduce learning rate on plateau, by default False
51 | log_output : bool, optional
52 | Log the output to tensorflow logfolder, by default False
53 | save_model_only : bool, optional
54 | Save best model after each epoch, by default False
55 | compile_only : bool, optional
56 | If true returns only compiled model, by default False
57 |
58 | Returns
59 | -------
60 | model : object
61 | Keras model
62 | history : dict
63 | Dictionary containing training and validation performance
64 | """
65 | # Compile model
66 | opt = optimizers.Adam(lr=training_generator.learning_rate, amsgrad=True)
67 | # Check if there are multiple outputs
68 | for key, item in loss_functions.items():
69 | try:
70 | function_handle = getattr(util.custom_losses, item)
71 | except (AttributeError, TypeError) as e:
72 | function_handle = item
73 | loss_functions[key] = function_handle
74 | model.compile(loss=loss_functions, optimizer=opt, loss_weights=loss_weights, metrics=loss_metrics)
75 | if compile_only: # What a hack. Keras bug from Oct9 in saving/loading models.
76 | return model
77 | # Get model name for storing tmp files
78 |     if model_name == '':
79 | model_name = training_generator.get_name()
80 | # Initiate callbacks
81 | callbacks = []
82 | if reduce_lr:
83 | reduce_lr_cp = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, verbose=1)
84 | callbacks.append(reduce_lr_cp)
85 | if log_output:
86 | tensorboard_cp = TensorBoard(log_dir=tensorboard_logfolder)
87 | callbacks.append(tensorboard_cp)
88 | if save_model_only:
89 | file_name = model_name + '.hdf5'
90 | model_cp = ModelCheckpoint(filepath=file_name, save_best_only=True, save_weights_only=True)
91 | callbacks.append(model_cp)
92 | # Run model training
93 | try:
94 | history = model.fit(training_generator, steps_per_epoch=steps_per_epoch, epochs=epochs, shuffle=training_generator.shuffle,
95 | validation_steps=validation_steps, validation_data=testing_generator, verbose=verbose, callbacks=callbacks)
96 | except KeyboardInterrupt:
97 | print('-> Notebook interrupted')
98 | history = []
99 | finally:
100 | if save_model_only: # Make sure interruption of jupyter notebook returns best model
101 | model.load_weights(file_name)
102 | print('-> Returning best Model')
103 | return (model, history)
104 |
105 |
106 | def train_model(model_path, path_in, tensorboard_logfolder, model_tmp_path, loss_functions, loss_weights, user_opts, num_cvs=5, verbose=0):
107 | """
108 | Trains the model across the experiment using cross validation and saves the model files
109 | TODO Save models back to HDF5 to keep everything in one place
110 |
111 | Parameters
112 | ----------
113 | model_path : str
114 | Path to where model should be stored
115 | path_in : str
116 | Path to HDF5 File
117 | tensorboard_logfolder : str
118 | Path to where tensorboard logs should be stored
119 | model_tmp_path : str
120 | Temporary file path used for returning best fit model
121 | loss_functions : dict
122 | For each output the corresponding loss function
123 | loss_weights : dict
124 | For each output the corresponding weight
125 | user_opts : dict
126 | Model parameters in case default opts should be changed
127 | num_cvs : int, optional
128 | Number of cross validation splits, by default 5
129 | """
130 | # Get experiment length
131 | hdf5_file = h5py.File(path_in, mode='r')
132 | tmp_wavelets = hdf5_file['inputs/wavelets']
133 | tmp_opts = util.opts.get_opts(path_in, train_test_times=(np.array([]), np.array([])))
134 | # check for user options
135 | if user_opts is not None:
136 | for key, value in user_opts.items():
137 | tmp_opts[key] = value
138 | exp_indices = np.arange(0, tmp_wavelets.shape[0] - (tmp_opts['model_timesteps'] * tmp_opts['batch_size']))
139 | cv_splits = np.array_split(exp_indices, num_cvs)
140 | for cv_run, cvs in enumerate(cv_splits):
141 | K.clear_session()
142 | # For cv
143 | training_indices = np.setdiff1d(exp_indices, cvs) # All except the test indices
144 | testing_indices = cvs
145 | # opts -> generators -> model
146 | opts = util.opts.get_opts(path_in, train_test_times=(training_indices, testing_indices))
147 | opts['loss_functions'] = loss_functions.copy()
148 | opts['loss_weights'] = loss_weights
149 | opts['loss_names'] = list(loss_functions.keys())
150 | opts['num_cvs'] = num_cvs
151 | # check for user options
152 | if user_opts is not None:
153 | for key, value in user_opts.items():
154 | opts[key] = value
155 | (training_generator, testing_generator) = util.data_generator.create_train_and_test_generators(opts)
156 | model = get_model_from_function(training_generator, show_summary=False)
157 |
158 | print('------------------------------------------------')
159 | print('-> Model and generators loaded')
160 | print('------------------------------------------------')
161 | if verbose > 0:
162 | print(model.summary())
163 |
164 | (model, history) = train_model_on_generator(model, training_generator, testing_generator, loss_functions=loss_functions.copy(), loss_weights=loss_weights, reduce_lr=True,
165 | log_output=opts['log_output'], tensorboard_logfolder=tensorboard_logfolder, model_name=model_tmp_path, save_model_only=opts['save_model'],
166 | steps_per_epoch=opts['steps_per_epoch'], validation_steps=opts['validation_steps'], epochs=opts['epochs'], loss_metrics=opts['metrics'])
167 | # Save model and history
168 | if history:
169 | opts['history'] = history.history
170 | cv_model_path = model_path[0:-3] + '_' + str(cv_run) + '.h5'
171 | util.hdf5.save_model_with_opts(model, opts, cv_model_path)
172 | print('------------------------------------------------')
173 | print('-> Model_{} saved to {}'.format(cv_run, cv_model_path))
174 | print('------------------------------------------------')
175 | hdf5_file.close()
176 |
177 |
178 | def run_from_path(path_in, loss_functions, loss_weights, user_opts=None, **args):
179 | """
180 | Runs model training giving path to HDF5 file and loss dictionaries
181 |
182 | Parameters
183 | ----------
184 | path_in : str
185 | Path to HDF5
186 | loss_functions : dict
187 | For each output the corresponding loss function
188 | loss_weights : dict
189 | For each output the corresponding weight
190 | """
191 | dirname = os.path.dirname(path_in)
192 | filename = os.path.basename(path_in)
193 | # Define folders
194 | tensorboard_logfolder = dirname + '/logs/' + filename[0:-3] # Remove .h5 for logfolder
195 | model_tmp_path = dirname + '/models/tmp/tmp_model'
196 | model_path = dirname + '/models/' + filename[0:-3] + '_model.h5'
197 | # Create folders if needed
198 | for f in [os.path.dirname(model_tmp_path), os.path.dirname(model_path)]:
199 | if not os.path.exists(f):
200 | os.makedirs(f)
201 | print('------------------------------------------------')
202 | print('-> Running {} from {}'.format(filename, dirname))
203 | print('- Logs : {} \n- Model temporary : {} \n- Model : {}'.format(tensorboard_logfolder, model_tmp_path, model_path))
204 | print('------------------------------------------------')
205 | # Train model
206 | print('------------------------------------------------')
207 | print('Starting standard model')
208 | print('------------------------------------------------')
209 | train_model(model_path, path_in, tensorboard_logfolder, model_tmp_path, loss_functions, loss_weights, user_opts, **args)
210 |
211 |
212 | def get_model_from_function(training_generator, show_summary=True):
213 | model_function = getattr(architecture, training_generator.model_function)
214 | model = model_function(training_generator, show_summary=show_summary)
215 |
216 | return model
217 |
--------------------------------------------------------------------------------
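
For reference, the cross-validation layout used in `train_model` above: the usable window start indices are cut into `num_cvs` contiguous blocks, and each run holds one block out for testing while training on the rest. A small standalone sketch (the index count is hypothetical):

```python
import numpy as np

exp_indices = np.arange(0, 1000)  # hypothetical usable window start indices
cv_splits = np.array_split(exp_indices, 5)
for cv_run, cvs in enumerate(cv_splits):
    training_indices = np.setdiff1d(exp_indices, cvs)  # everything except the held-out block
    print('cv {}: test [{}..{}], train size {}'.format(cv_run, cvs[0], cvs[-1], len(training_indices)))
```
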
/deepinsight/util/__init__.py:
--------------------------------------------------------------------------------
1 | from . import hdf5
2 | from . import tetrode
3 | from . import stats
4 | from . import custom_losses
5 | from . import wavelet_transform
6 | from . import opts
7 | from . import data_generator
8 |
--------------------------------------------------------------------------------
/deepinsight/util/custom_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from tensorflow.keras import backend as K
8 | import tensorflow as tf
9 | import numpy as np
10 |
11 |
12 | def euclidean_loss(y_true, y_pred):
13 | # We use tf.sqrt instead of K.sqrt as there is a bug in K.sqrt (as of March 14, 2018)
14 | res = tf.sqrt(K.sum(K.square(y_pred - y_true), axis=-1))
15 | return res
16 |
17 |
18 | def cyclical_mae_rad(y_true, y_pred):
19 | return K.mean(K.minimum(K.abs(y_pred - y_true), K.minimum(K.abs(y_pred - y_true + 2*np.pi), K.abs(y_pred - y_true - 2*np.pi))), axis=-1)
20 |
21 |
22 | def mse(y_true, y_pred):
23 | return tf.keras.losses.MSE(y_true, y_pred)
24 |
25 |
26 | def mae(y_true, y_pred):
27 | return tf.keras.losses.MAE(y_true, y_pred)
28 |
--------------------------------------------------------------------------------
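
A quick numeric check of what `cyclical_mae_rad` computes, mirrored in plain numpy so it runs without a TensorFlow session: angles 0.1 and 2π − 0.1 rad are 0.2 rad apart on the circle, not ≈6.08 rad as a plain MAE would report.

```python
import numpy as np

def cyclical_mae_rad_np(y_true, y_pred):
    # numpy mirror of the Keras loss above
    d = np.abs(y_pred - y_true)
    return np.mean(np.minimum(d, np.minimum(np.abs(y_pred - y_true + 2 * np.pi),
                                            np.abs(y_pred - y_true - 2 * np.pi))))

print(cyclical_mae_rad_np(np.array([0.1]), np.array([2 * np.pi - 0.1])))  # ~0.2
```
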
/deepinsight/util/data_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import pickle
8 | import os
9 | import numpy as np
10 | from tensorflow.keras.utils import Sequence
11 |
12 | from . import hdf5
13 |
14 |
15 | def create_train_and_test_generators(opts):
16 | """
17 | Creates training and test generators given opts dictionary
18 |
19 | Parameters
20 | ----------
21 | opts : dict
22 | Dictionary holding options for data creation and model training
23 |
24 | Returns
25 | -------
26 | training_generator : object
27 | Sequence class used for generating training data
28 | testing_generator : object
29 | Sequence class used for generating testing data
30 | """
31 | # 1.) Create training generator
32 | training_generator = RawWaveletSequence(opts, training=True)
33 | # 2.) Create testing generator
34 | testing_generator = RawWaveletSequence(opts, training=False)
35 | # 3.) Assert that training and testing data are different
36 |
37 | return (training_generator, testing_generator)
38 |
39 |
40 | class RawWaveletSequence(Sequence):
41 | """
42 |     Data generator class. Important functions are get_input_sample and get_output_sample.
43 | Each call to __getitem__ will yield a (input, output) pair
44 |
45 | Parameters
46 | ----------
47 | Sequence : object
48 | Keras sequence
49 |
50 | Yields
51 | -------
52 | input_sample : array_like
53 | Batched input for model training
54 | output_sample : array_like
55 | Batched output for model optimization
56 | """
57 |
58 | def __init__(self, opts, training):
59 | # 1.) Set all options as attributes
60 | self.set_opts_as_attribute(opts)
61 |
62 | # 2.) Load data memmaped for mean/std estimation and fast plotting
63 | self.wavelets = hdf5.read_hdf_memmapped(self.fp_hdf_out, 'inputs/wavelets')
64 |
65 | # Get output(s)
66 | outputs = []
67 | for key, value in opts['loss_functions'].items():
68 | tmp_out = hdf5.read_hdf_memmapped(self.fp_hdf_out, 'outputs/' + key)
69 | outputs.append(tmp_out)
70 | self.outputs = outputs
71 |
72 | # 3.) Prepare for training
73 | self.training = training
74 | self.prepare_data_generator(training=training)
75 |
76 | def __len__(self):
77 | return len(self.cv_indices)
78 |
79 | def __getitem__(self, idx):
80 | # 1.) Define start and end index
81 | if self.shuffle:
82 | idx = np.random.choice(self.cv_indices)
83 | else:
84 | idx = self.cv_indices[idx]
85 | cut_range = np.arange(idx, idx + self.sample_size)
86 |
87 | # 2.) Above takes consecutive batches, implement random batching here
88 | if self.random_batches:
89 | indices = np.random.choice(self.cv_indices, size=self.batch_size)
90 | cut_range = [np.arange(start_index, start_index + self.model_timesteps) for start_index in indices]
91 | cut_range = np.array(cut_range)
92 | else:
93 | cut_range = np.reshape(cut_range, (self.batch_size, cut_range.shape[0] // self.batch_size))
94 |
95 | # 3.) Get input sample
96 | input_sample = self.get_input_sample(cut_range)
97 |
98 | # 4.) Get output sample
99 | output_sample = self.get_output_sample(cut_range)
100 |
101 | return (input_sample, output_sample)
102 |
103 | def get_input_sample(self, cut_range):
104 | # 1.) Cut Ephys / fancy indexing for memmap is planned, if fixed use: cut_data = self.wavelets[cut_range, self.fourier_frequencies, self.channels]
105 | cut_data = self.wavelets[cut_range, :, :]
106 | cut_data = np.reshape(cut_data, (cut_data.shape[0] * cut_data.shape[1], cut_data.shape[2], cut_data.shape[3]))
107 |
108 | # 2.) Normalize input
109 | cut_data = (cut_data - self.est_mean) / self.est_std
110 |
111 | # 3.) Reshape for model input
112 | cut_data = np.reshape(cut_data, (self.batch_size, self.model_timesteps, cut_data.shape[1], cut_data.shape[2]))
113 |
114 | # 4.) Take care of optional settings
115 | cut_data = np.transpose(cut_data, axes=(0, 3, 1, 2))
116 | cut_data = cut_data[..., np.newaxis]
117 |
118 | return cut_data
119 |
120 | def get_output_sample(self, cut_range):
121 | # 1.) Cut Ephys
122 | out_sample = []
123 | for out in self.outputs:
124 | cut_data = out[cut_range, ...]
125 | cut_data = np.reshape(cut_data, (cut_data.shape[0] * cut_data.shape[1], cut_data.shape[2]))
126 |
127 | # 2.) Reshape for model output
128 |             if len(cut_data.shape) != self.batch_size:
129 | cut_data = np.reshape(cut_data, (self.batch_size, self.model_timesteps, cut_data.shape[1]))
130 |
131 | # 3.) Divide evenly and make sure last output is being decoded
132 | if self.average_output:
133 | cut_data = cut_data[:, np.arange(0, cut_data.shape[1] + 1, self.average_output)[1::] - 1]
134 | out_sample.append(cut_data)
135 |
136 | return out_sample
137 |
138 | def prepare_data_generator(self, training):
139 | # 1.) Define sample size and means
140 | self.sample_size = self.model_timesteps * self.batch_size
141 |
142 | if training:
143 | self.cv_indices = self.training_indices
144 | else:
145 | self.cv_indices = self.testing_indices
146 |
147 |         # Make sure random choice takes from an array, not a list (500x speedup)
148 | self.cv_indices = np.array(self.cv_indices)
149 |
150 | # 9.) Calculate normalization for wavelets
151 | meanstd_path = os.path.dirname(self.fp_hdf_out) + '/models/tmp/' + os.path.basename(self.fp_hdf_out)[:-3] + '_meanstd_start{}_end{}_tstart{}_tend{}.p'.format(
152 | self.training_indices[0], self.training_indices[-1], self.testing_indices[0], self.testing_indices[-1])
153 |
154 | if os.path.exists(meanstd_path):
155 | (self.est_mean, self.est_std) = pickle.load(open(meanstd_path, 'rb'))
156 | else:
157 | print('Calculating MAD normalization parameters')
158 | if len(self.training_indices) > 1e5:
159 | print('Downsampling wavelets for MAD calculation')
160 | self.est_mean = np.median(self.wavelets[self.training_indices[::100], :, :], axis=0)
161 | self.est_std = np.median(abs(self.wavelets[self.training_indices[::100], :, :] - self.est_mean), axis=0)
162 | else:
163 | self.est_mean = np.median(self.wavelets[self.training_indices, :, :], axis=0)
164 | self.est_std = np.median(abs(self.wavelets[self.training_indices, :, :] - self.est_mean), axis=0)
165 | pickle.dump((self.est_mean, self.est_std), open(meanstd_path, 'wb'))
166 |
167 | # Make sure indices contain no NaN values
168 | if self.handle_nan:
169 | self.cv_indices = self.check_for_nan()
170 |
171 | # 10.) Define output shape. Most robust way is to get a dummy input and take that shape as output shape
172 | (dummy_input, dummy_output) = self.__getitem__(0)
173 |         # Corresponds to the output of this generator, aka input to the model. Also remove the batch dimension
174 | self.input_shape = dummy_input.shape[1:]
175 |
176 | def set_opts_as_attribute(self, opts):
177 | for k, v in opts.items():
178 | setattr(self, k, v)
179 |
180 | def get_name(self):
181 | name = ""
182 | for attr in self.important_attributes:
183 | name += attr + ':{},'.format(getattr(self, attr))
184 | return name[:-1]
185 |
186 | def check_for_nan(self):
187 | new_cv_indices, len_before = [], len(self.cv_indices)
188 | for idx, cv in enumerate(self.cv_indices):
189 | if not idx % 100000:
190 | print('{} / {}'.format(cv, self.cv_indices[-1]), end='\r')
191 | cut_range = np.arange(cv, cv + self.sample_size)
192 | cut_range = np.reshape(cut_range, (self.batch_size, cut_range.shape[0] // self.batch_size))
193 | out_sample = self.get_output_sample(cut_range)
194 | nan_in_out = any([any(np.isnan(x.flatten())) for x in out_sample[0]])
195 | if not nan_in_out:
196 | new_cv_indices.append(cv)
197 | print('Len before {}, after {} --- Diff {}'.format(len_before, len(new_cv_indices), len_before - len(new_cv_indices)))
198 | return np.array(new_cv_indices)
199 |
--------------------------------------------------------------------------------
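
The normalization in `prepare_data_generator` is a median/MAD (median absolute deviation) scheme rather than mean/std, which makes it robust to the heavy-tailed amplitude outliers common in raw recordings. A standalone sketch with hypothetical shapes:

```python
import numpy as np

rng = np.random.default_rng(1)
wavelets = rng.lognormal(size=(5000, 26, 4))  # hypothetical (N timepoints, P frequencies, M channels)

est_mean = np.median(wavelets, axis=0)                    # per frequency/channel median, shape (P, M)
est_std = np.median(np.abs(wavelets - est_mean), axis=0)  # median absolute deviation, shape (P, M)
normalized = (wavelets - est_mean) / est_std              # as applied in get_input_sample
```
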
/deepinsight/util/hdf5.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import h5py
8 | import numpy as np
9 | from . import data_generator
10 |
11 |
12 | def create_or_update(hdf5_file, dataset_name, dataset_shape, dataset_type, dataset_value):
13 | """
14 | Create or update dataset in HDF5 file
15 |
16 | Parameters
17 | ----------
18 | hdf5_file : File
19 | File identifier
20 | dataset_name : str
21 | Name of new dataset
22 | dataset_shape : array_like
23 | Shape of new dataset
24 | dataset_type : type
25 | Type of dataset (np.float16, np.float32, 'S', etc...)
26 | dataset_value : array_like
27 | Data to store in HDF5 file
28 | """
29 |     if dataset_name not in hdf5_file:
30 | hdf5_file.create_dataset(dataset_name, dataset_shape, dataset_type)
31 | hdf5_file[dataset_name][:] = dataset_value
32 | else:
33 | if hdf5_file[dataset_name].shape != dataset_shape:
34 | del hdf5_file[dataset_name]
35 | hdf5_file.create_dataset(dataset_name, dataset_shape, dataset_type)
36 | hdf5_file[dataset_name][:] = dataset_value
37 | hdf5_file.flush()
38 |
39 |
40 | def save_model_with_opts(model, opts, file_name):
41 | """
42 | Saves Keras model and training options to HDF5 file
43 | Uses Keras save_weights for creating the model HDF5 file and then inserts into that
44 |
45 | Parameters
46 | ----------
47 | model : object
48 | Keras model
49 | opts : dict
50 | Dictionary used for training the model
51 | file_name : str
52 | Path to save to
53 | """
54 | model.save_weights(file_name)
55 | hdf5_file = h5py.File(file_name, mode='a')
56 | hdf5_file['opts'] = str(opts)
57 | hdf5_file.flush()
58 | hdf5_file.close()
59 |
60 |
61 | def load_model_with_opts(file_name):
62 | """
63 | Load Keras model and training options from HDF5 file
64 | TODO: Remove eval and find better way of storing dict in HDF5 (hickle, pytables, etc...)
65 |
66 | Parameters
67 | ----------
68 | file_name : str
69 | Model path
70 |
71 | Returns
72 | -------
73 | model : object
74 | Keras model
75 | training_generator : object
76 | Datagenerator used to create training samples on the fly
77 | testing_generator : object
78 | Datagenerator used to create testing samples on the fly
79 | opts : dict
80 | Dictionary used for training the model
81 | """
82 | from .. import train
83 | # Get options from dictionary, stored as str in HDF5 (not recommended, TODO)
84 | hdf5_file = h5py.File(file_name, mode='r')
85 | opts = eval(hdf5_file['opts'][()])
86 | opts['handle_nan'] = False
87 | hdf5_file.close()
88 |
89 | # Use options to create data generators and model weights
90 | (training_generator, testing_generator) = data_generator.create_train_and_test_generators(opts)
91 |
92 | model = train.get_model_from_function(training_generator, show_summary=False)
93 | model = train.train_model_on_generator(model, training_generator, testing_generator,
94 | loss_functions=opts['loss_functions'], loss_weights=opts['loss_weights'], compile_only=True)
95 | model.load_weights(file_name)
96 |
97 | return (model, training_generator, testing_generator, opts)
98 |
99 |
100 | def read_hdf_memmapped(fn_hdf, hdf_group):
101 | """
102 | Reads the hdf file as a numpy memmapped file, makes slicing a bit faster
103 | (From https://gist.github.com/rossant/7b4704e8caeb8f173084)
104 |
105 | Parameters
106 | ----------
107 | fn_hdf : str
108 | Path to preprocessed HDF5
109 | hdf_group : str
110 | Group to read from HDF5
111 |
112 | Returns
113 | -------
114 | data : array_like
115 | Data as a memory mapped array
116 | """
117 | # Define function for memmapping
118 | def _mmap_h5(path, h5path):
119 | with h5py.File(path, mode='r') as f:
120 | ds = f[h5path]
121 | # We get the dataset address in the HDF5 file.
122 | offset = ds.id.get_offset()
123 | # We ensure we have a non-compressed contiguous array.
124 | assert ds.chunks is None
125 | assert ds.compression is None
126 | assert offset > 0
127 | dtype = ds.dtype
128 | shape = ds.shape
129 | arr = np.memmap(path, mode='r', shape=shape,
130 | offset=offset, dtype=dtype)
131 | return arr
132 | # Load data
133 | data = _mmap_h5(fn_hdf, hdf_group)
134 |
135 | return data
136 |
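# Usage sketch (hypothetical file): the memmap keeps the array on disk, so only
# the slices you index are actually read. This requires a contiguous,
# uncompressed dataset, which the asserts above enforce.
#
#   wavelets = read_hdf_memmapped('processed.h5', 'inputs/wavelets')
#   first_window = wavelets[0:64, ...]  # only this slice is read from disk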
--------------------------------------------------------------------------------
/deepinsight/util/opts.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 |
8 |
9 | def get_opts(fp_hdf_out, train_test_times):
10 | """
11 | Returns the options dictionary which contains all parameters needed to
12 | create DataGenerator and train the model
13 | TODO Find better method of parameter storing (config files, store in HDF5, etc...)
14 |
15 | Parameters
16 | ----------
17 | fp_hdf_out : str
18 | File path to HDF5 file
19 | train_test_times : array_like
20 | Indices for training and testing generator
21 |
22 | Returns
23 | -------
24 | opts : dict
25 | Dictionary containing all model and training parameters
26 | """
27 | opts = dict()
28 | # -------- DATA ------------------------
29 | opts['fp_hdf_out'] = fp_hdf_out # Filepath for hdf5 file storing wavelets and outputs
30 | opts['sampling_rate'] = 30 # Sampling rate of the wavelets
31 | opts['training_indices'] = train_test_times[0].tolist() # Indices into wavelets used for training the model, adjusted during CV
32 | opts['testing_indices'] = train_test_times[1].tolist() # Indices into wavelets used for testing the model, adjusted during CV
33 |
34 | # -------- MODEL PARAMETERS --------------
35 | opts['model_function'] = 'the_decoder' # Model architecture used
36 | opts['model_timesteps'] = 64 # How many timesteps are used in the input layer, e.g. at a sampling rate of 30 this yields 2.13 s windows. Must be divisible by 2 X times, where X = 'num_convs_tsr'
37 | opts['num_convs_tsr'] = 4 # Number of downsampling steps within the model, e.g. with model_timesteps=64, it will downsample 64->32->16->8->4 and output 4 timesteps
38 | opts['average_output'] = 2**opts['num_convs_tsr'] # Ratio between input and output shape
39 | opts['channel_lower_limit'] = 2
40 |
41 | opts['optimizer'] = 'adam' # Learning algorithm
42 | opts['learning_rate'] = 0.0007 # Learning rate
43 | opts['kernel_size'] = 3 # Kernel size for all convolutional layers
44 | opts['conv_padding'] = 'same' # Which padding should be used for the convolutional layers
45 | opts['act_conv'] = 'elu' # Activation function for convolutional layers
46 | opts['act_fc'] = 'elu' # Activation function for fully connected layers
47 | opts['dropout_ratio'] = 0 # Dropout ratio for fully connected layers
48 | opts['filter_size'] = 64 # Number of filters in convolutional layers
49 | opts['num_units_dense'] = 1024 # Number of units in fully connected layer
50 | opts['num_dense'] = 2 # Number of fully connected layers
51 | opts['gaussian_noise'] = 1 # How much gaussian noise is added (unit = standard deviation)
52 |
53 | # -------- TRAINING----------------------
54 | opts['batch_size'] = 8 # Batch size used for training the model
55 | opts['steps_per_epoch'] = 250 # Number of steps per training epoch
56 | opts['validation_steps'] = 250 # Number of steps per validation epoch
57 | opts['epochs'] = 20 # Number of epochs
58 | opts['shuffle'] = True # If input should be shuffled
59 | opts['random_batches'] = True # If random batches in time are used
60 | opts['metrics'] = []
61 | opts['last_layer_activation_function'] = 'linear'
62 | opts['handle_nan'] = False
63 |
64 | # -------- MISC --------------------------
65 | opts['tensorboard_logfolder'] = './' # Logfolder for tensorboard
66 | opts['model_folder'] = './' # Folder for saving the model
67 | opts['log_output'] = False # If output should be logged
68 | opts['save_model'] = False # If model should be saved
69 |
70 | return opts
71 |
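# Usage sketch (hypothetical 80/20 split; in the full pipeline train_test_times
# comes from the cross-validation loop in deepinsight.train):
#
#   import numpy as np
#   n_samples = 1000
#   split = int(0.8 * n_samples)
#   train_test_times = (np.arange(0, split), np.arange(split, n_samples))
#   opts = get_opts('processed.h5', train_test_times)
#   opts['epochs'] = 5  # individual entries can be overridden afterwards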
--------------------------------------------------------------------------------
/deepinsight/util/stats.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import numpy as np
8 |
9 |
10 | def calculate_speed_from_position(positions, interval, smoothing=False):
11 | """
12 | Calculate speed from X,Y coordinates
13 |
14 | Parameters
15 | ----------
16 | positions : (N, 2) array_like
17 | N samples of observations, containing X and Y coordinates
18 | interval : int
19 | Duration between observations (in s, equal to 1 / sr)
20 | smoothing : bool or int, optional
21 | If speeds should be smoothed, by default False/0
22 |
23 | Returns
24 | -------
25 | speed : (N,) array_like
26 | Instantaneous speed of the animal
27 | """
28 | X, Y = positions[:, 0], positions[:, 1]
29 | # Smooth diffs instead of speeds directly
30 | Xdiff = np.diff(X)
31 | Ydiff = np.diff(Y)
32 | if smoothing:
33 | Xdiff = smooth_signal(Xdiff, smoothing)
34 | Ydiff = smooth_signal(Ydiff, smoothing)
35 | speed = np.sqrt(Xdiff**2 + Ydiff**2) / interval
36 | speed = np.append(speed, speed[-1])
37 |
38 | return speed
39 |
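# Worked sketch (synthetic data): an animal moving along X at a constant
# 1 unit/s, sampled at 30 Hz, should yield a speed of ~1 at every sample.
#
#   t = np.arange(0, 10, 1 / 30)
#   positions = np.column_stack([t, np.zeros_like(t)])
#   speed = calculate_speed_from_position(positions, interval=1 / 30)
#   # speed has shape (N,) and is approximately 1.0 everywhere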
40 |
41 | def calculate_heading_direction_from_position(positions, smoothing=False, return_as_deg=False):
42 | """
43 | Calculates heading direction based on X and Y coordinates. With a single tracking point we can only estimate heading direction (direction of movement), not head direction
44 |
45 | Parameters
46 | ----------
47 | positions : (N, 2) array_like
48 | N samples of observations, containing X and Y coordinates
49 | smoothing : bool or int, optional
50 | If speeds should be smoothed, by default False/0
51 | return_as_deg : bool
52 | If True, return heading in degrees instead of radians, by default False
53 |
54 | Returns
55 | -------
56 | heading_direction : (N,) array_like
57 | Heading direction of the animal
58 | """
59 | X, Y = positions[:, 0], positions[:, 1]
60 | # Smooth diffs instead of speeds directly
61 | Xdiff = np.diff(X)
62 | Ydiff = np.diff(Y)
63 | if smoothing:
64 | Xdiff = smooth_signal(Xdiff, smoothing)
65 | Ydiff = smooth_signal(Ydiff, smoothing)
66 | # Calculate heading direction
67 | heading_direction = np.arctan2(Ydiff, Xdiff)
68 | heading_direction = np.append(heading_direction, heading_direction[-1])
69 | if return_as_deg:
70 | heading_direction = heading_direction * (180 / np.pi)
71 |
72 | return heading_direction
73 |
74 |
75 | def calculate_head_direction_from_leds(positions, return_as_deg=False):
76 | """
77 | Calculates head direction based on X and Y coordinates with two LEDs.
78 |
79 | Parameters
80 | ----------
81 | positions : (N, 2) array_like
82 | N samples of observations, containing X and Y coordinates
83 | return_as_deg : bool
84 | If True, return head direction in degrees instead of radians, by default False
85 |
86 | Returns
87 | -------
88 | head_direction : (N,) array_like
89 | Head direction of the animal
90 | """
91 | X_led1, Y_led1, X_led2, Y_led2 = positions[:, 0], positions[:, 1], positions[:, 2], positions[:, 3]
92 | # Calculate head direction
93 | head_direction = np.arctan2(X_led1 - X_led2, Y_led1 - Y_led2)
94 | # Put in right perspective in relation to the environment
95 | offset = +np.pi/2
96 | head_direction = (head_direction + offset + np.pi) % (2*np.pi) - np.pi
97 | head_direction *= -1
98 | if return_as_deg:
99 | head_direction = head_direction * (180 / np.pi)
100 |
101 | return head_direction
102 |
103 |
104 | def smooth_signal(signal, N):
105 | """
106 | Simple smoothing by convolving the signal with a boxcar filter of height 1/N.
107 |
108 | Parameters
109 | ----------
110 | signal : array_like
111 | Signal to be smoothed
112 | N : int
113 | Width of the smoothing window in samples (smoothing factor)
114 |
115 | Returns
116 | -------
117 | signal : array_like
118 | Smoothed signal
119 | """
120 | # Preprocess edges
121 | signal = np.concatenate([signal[0:N], signal, signal[-N:]])
122 | # Convolve
123 | signal = np.convolve(signal, np.ones((N,))/N, mode='same')
124 | # Postprocess edges
125 | signal = signal[N:-N]
126 |
127 | return signal
128 |
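# Worked sketch: smoothing a noisy ramp with N=5. The edge padding above keeps
# the output the same length as the input.
#
#   noisy = np.arange(100) + np.random.randn(100)
#   smoothed = smooth_signal(noisy, 5)
#   assert smoothed.shape == noisy.shape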
--------------------------------------------------------------------------------
/deepinsight/util/tetrode.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import numpy as np
8 | import pandas as pd
9 | import h5py
10 |
11 | from . import hdf5
12 | from . import stats
13 |
14 |
15 | def read_open_ephys(fp_raw_file):
16 | """
17 | Reads ST Open Ephys recordings (NWB format)
18 |
19 | Parameters
20 | ----------
21 | fp_raw_file : str
22 | File path to open ephys file
23 |
24 | Returns
25 | -------
26 | continuous : (N,M) array_like
27 | Continuous ephys with N timepoints and M channels
28 | timestamps : (N,1) array_like
29 | Timestamps for each sample in continuous
30 | positions : (N,5) array_like
31 | Position of animal with two LEDs and timestamps
32 | info : object
33 | Additional information about experiments
34 | """
35 | fid_ephys = h5py.File(fp_raw_file, mode='r')
36 |
37 | # Load timestamps and continuous data, python 3 keys() returns view
38 | recording_key = list(fid_ephys['acquisition']['timeseries'].keys())[0]
39 | processor_key = list(fid_ephys['acquisition']['timeseries'][recording_key]['continuous'].keys())[0]
40 |
41 | # Load raw ephys and timestamps
42 | # Not converted to microvolts (multiply by 0.195 to convert). We don't multiply here as we can't load the full array into memory
43 | continuous = fid_ephys['acquisition']['timeseries'][recording_key]['continuous'][processor_key]['data']
44 | timestamps = fid_ephys['acquisition']['timeseries'][recording_key]['continuous'][processor_key]['timestamps']
45 |
46 | # We can also read position directly from the raw file
47 | positions = fid_ephys['acquisition']['timeseries'][recording_key]['tracking']['ProcessedPos']
48 |
49 | # Read general settings
50 | info = fid_ephys['general']['data_collection']['Settings']
51 |
52 | return (continuous, timestamps, positions, info)
53 |
54 |
55 | def read_tetrode_data(fp_raw_file):
56 | """
57 | Read ST data from openEphys recording system
58 |
59 | Parameters
60 | ----------
61 | fp_raw_file : str
62 | File path to open ephys file
63 |
64 | Returns
65 | -------
66 | raw_data : (N,M) array_like
67 | Continuous ephys with N timepoints and M channels
68 | raw_timestamps : (N,1) array_like
69 | Timestamps for each sample in continuous
70 | output : (N,4) array_like
71 | Position of animal with two LEDs
72 | output_timestamps : (N,1) array_like
73 | Timestamps for positions
74 | info : object
75 | Additional information about experiments
76 | """
77 | (raw_data, raw_timestamps, positions, info) = read_open_ephys(fp_raw_file)
78 | output_timestamps = positions[:, 0]
79 | output = positions[:, 1:5]
80 | bad_channels = info['General']['badChan']
81 | bad_channels = [int(n) for n in bad_channels[()].decode('UTF-8').split(',')]
82 | good_channels = np.delete(np.arange(0, 128), bad_channels)
83 | info = {'channels': good_channels, 'bad_channels': bad_channels, 'sampling_rate': 30000}
84 |
85 | return (raw_data, raw_timestamps, output, output_timestamps, info)
86 |
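# Usage sketch (hypothetical recording file, as in the ephys example notebook):
# raw_data stays an HDF5 dataset handle, so the full recording is never loaded
# into memory until it is sliced.
#
#   (raw_data, raw_timestamps, output,
#    output_timestamps, info) = read_tetrode_data('experiment_1.nwb')
#   print(info['sampling_rate'], len(info['channels']))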
87 |
88 | def preprocess_output(fp_hdf_out, raw_timestamps, output, output_timestamps, average_window=1000, sampling_rate=30000):
89 | """
90 | Write behaviours to decode into HDF5 file
91 |
92 | Parameters
93 | ----------
94 | fp_hdf_out : str
95 | File path to HDF5 file
96 | raw_timestamps : (N,1) array_like
97 | Timestamps for each sample in continuous
98 | output : (N,4) array_like
99 | Position of animal with two LEDs
100 | output_timestamps : (N,1) array_like
101 | Timestamps for positions
102 | average_window : int, optional
103 | Downsampling factor for raw data and positions, by default 1000
104 | sampling_rate : int, optional
105 | Sampling rate of raw ephys, by default 30000
106 | """
107 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
108 |
109 | # Get size of wavelets
110 | input_length = hdf5_file['inputs/wavelets'].shape[0]
111 |
112 | # Get positions of both LEDs
113 | raw_timestamps = raw_timestamps[()] # Slightly faster than np.array
114 | output_x_led1 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
115 | average_window)], output_timestamps, output[:, 0])
116 | output_y_led1 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
117 | average_window)], output_timestamps, output[:, 1])
118 | output_x_led2 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
119 | average_window)], output_timestamps, output[:, 2])
120 | output_y_led2 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
121 | average_window)], output_timestamps, output[:, 3])
122 | raw_positions = np.array([output_x_led1, output_y_led1, output_x_led2, output_y_led2]).transpose()
123 |
124 | # Clean raw_positions and get centre
125 | positions_smooth = pd.DataFrame(raw_positions.copy()).interpolate(
126 | limit_direction='both').rolling(5, min_periods=1).mean().values
127 | position = np.array([(positions_smooth[:, 0] + positions_smooth[:, 2]) / 2,
128 | (positions_smooth[:, 1] + positions_smooth[:, 3]) / 2]).transpose()
129 |
130 | # Also get head direction and speed from positions
131 | speed = stats.calculate_speed_from_position(position, interval=1/(sampling_rate//average_window), smoothing=3)
132 | head_direction = stats.calculate_head_direction_from_leds(positions_smooth, return_as_deg=False)
133 |
134 | # Create and save datasets in HDF5 File
135 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/raw_position",
136 | dataset_shape=[input_length, 4], dataset_type=np.float16, dataset_value=raw_positions[0: input_length, :])
137 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/position",
138 | dataset_shape=[input_length, 2], dataset_type=np.float16, dataset_value=position[0: input_length, :])
139 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/head_direction", dataset_shape=[
140 | input_length, 1], dataset_type=np.float16, dataset_value=head_direction[0: input_length, np.newaxis])
141 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/speed",
142 | dataset_shape=[input_length, 1], dataset_type=np.float16, dataset_value=speed[0: input_length, np.newaxis])
143 | hdf5_file.flush()
144 | hdf5_file.close()
145 |
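# Usage sketch, continuing from read_tetrode_data above (hypothetical file;
# assumes preprocess_input has already written inputs/wavelets to the file).
# With sampling_rate=30000 and average_window=1000 the behavioural outputs are
# written at 30 Hz, matching the downsampled wavelet input.
#
#   preprocess_output('processed.h5', raw_timestamps, output, output_timestamps)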
--------------------------------------------------------------------------------
/deepinsight/util/wavelet_transform.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from wavelets import WaveletAnalysis
8 | import numpy as np
9 |
10 |
11 | def wavelet_transform(signal, sampling_rate, average_window=1000, scaling_factor=0.25, wave_highpass=2, wave_lowpass=30000):
12 | """
13 | Calculates the wavelet transform for each point in signal, then averages
14 | each window and returns the result together with the Fourier frequencies
15 |
16 | Parameters
17 | ----------
18 | signal : (N,1) array_like
19 | Signal to be transformed
20 | sampling_rate : int
21 | Sampling rate of signal
22 | average_window : int, optional
23 | Average window to downsample wavelet transformed input, by default 1000
24 | scaling_factor : float, optional
25 | Determines amount of log-spaced frequencies M in output, by default 0.25
26 | wave_highpass : int, optional
27 | Cut off frequencies below this value, by default 2
28 | wave_lowpass : int, optional
29 | Cut off frequencies above this value, by default 30000
30 |
31 | Returns
32 | -------
33 | wavelet_power : (N, M) array_like
34 | Wavelet transformed signal
35 | wavelet_frequencies : (M, 1) array_like
36 | Corresponding frequencies to wavelet_power
37 | """
38 | (wavelet_power, wavelet_frequencies, wavelet_obj) = simple_wavelet_transform(signal, sampling_rate,
39 | scaling_factor=scaling_factor, wave_highpass=wave_highpass, wave_lowpass=wave_lowpass)
40 |
41 | # Average over window
42 | if average_window != 1:
43 | wavelet_power = np.reshape(
44 | wavelet_power, (wavelet_power.shape[0], wavelet_power.shape[1] // average_window, average_window))
45 | wavelet_power = np.mean(wavelet_power, axis=2).transpose()
46 | else:
47 | wavelet_power = wavelet_power.transpose()
48 |
49 | return wavelet_power, wavelet_frequencies
50 |
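# Worked sketch (synthetic signal): a 5 Hz sine sampled at 1 kHz. With
# average_window=100 the output is downsampled 100x in time and the power
# should concentrate in the frequency bins around 5 Hz.
#
#   t = np.arange(0, 10, 1 / 1000)
#   sig = np.sin(2 * np.pi * 5 * t)
#   power, freqs = wavelet_transform(sig, sampling_rate=1000, average_window=100,
#                                    wave_highpass=2, wave_lowpass=500)
#   # power has shape (len(sig) // 100, len(freqs))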
51 |
52 | def simple_wavelet_transform(signal, sampling_rate, scaling_factor=0.25, wave_lowpass=None, wave_highpass=None):
53 | """
54 | Simple wavelet transformation of signal
55 |
56 | Parameters
57 | ----------
58 | signal : (N,1) array_like
59 | Signal to be transformed
60 | sampling_rate : int
61 | Sampling rate of signal
62 | scaling_factor : float, optional
63 | Determines amount of log-spaced frequencies M in output, by default 0.25
64 | wave_highpass : int, optional
65 | Cut off frequencies below this value, by default None
66 | wave_lowpass : int, optional
67 | Cut off frequencies above this value, by default None
68 |
69 | Returns
70 | -------
71 | wavelet_power : (N, M) array_like
72 | Wavelet transformed signal
73 | wavelet_frequencies : (M, 1) array_like
74 | Corresponding frequencies to wavelet_power
75 | wavelet_obj : object
76 | WaveletTransform Object
77 | """
78 | wavelet_obj = WaveletAnalysis(signal, dt=1 / sampling_rate, dj=scaling_factor)
79 | wavelet_power = wavelet_obj.wavelet_power
80 | wavelet_frequencies = wavelet_obj.fourier_frequencies
81 |
82 | if wave_lowpass or wave_highpass:
83 | wavelet_power = wavelet_power[(wavelet_frequencies < wave_lowpass) & (wavelet_frequencies > wave_highpass), :]
84 | wavelet_frequencies = wavelet_frequencies[(wavelet_frequencies < wave_lowpass) & (wavelet_frequencies > wave_highpass)]
85 |
86 | return (wavelet_power, wavelet_frequencies, wavelet_obj)
87 |
--------------------------------------------------------------------------------
/deepinsight/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import numpy as np
8 | import h5py
9 | import pandas as pd
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 | sns.set_style('white')
13 |
14 |
15 | def plot_residuals(fp_hdf_out, output_names, losses=None, shuffled_losses=None, aggregator=np.mean, frequency_spacing=1, offset=0):
16 | """
17 | Plots influence plots for each output
18 |
19 | Parameters
20 | ----------
21 | fp_hdf_out : str
22 | File path to HDF5 file
23 | output_names : list of str
24 | Title for each decoded output, one subplot per output
25 | losses, shuffled_losses : array_like, optional
26 | Precomputed (shuffled) losses; read from the analysis/ group of the HDF5 file if None
27 | aggregator : function handle, optional
28 | Which aggregator to use for plotting the lineplots, by default np.mean
29 | frequency_spacing : int, optional
30 | Spacing on x axis between frequencies, by default 1
31 | offset : float, optional
32 | Constant added to the loss denominator to avoid division by zero, by default 0
33 | """
28 | # Read data from HDF5 file
29 | hdf5_file = h5py.File(fp_hdf_out, mode='r')
30 | if losses is None:
31 | losses = hdf5_file["analysis/losses"][()]
32 | if shuffled_losses is None:
33 | shuffled_losses = hdf5_file["analysis/influence/shuffled_losses"][()]
34 | frequencies = hdf5_file["inputs/fourier_frequencies"][()].astype(np.float32)
35 | hdf5_file.close()
36 |
37 | # Calculate residuals; avoid division by zero by adding a small constant (offset). TODO Should be relative to loss and only if needed
38 | residuals = (shuffled_losses - losses) / (losses + offset)
39 |
40 | # Plot
41 | fig, axes = plt.subplots(len(output_names), 1, figsize=(16, 8))
42 | if len(output_names) > 1:
43 | axes = axes.flatten()
44 | else:
45 | axes = [axes]
46 | for all_residuals, ax, on in zip(residuals.transpose(), axes, output_names):
47 | residuals_mean = np.mean(all_residuals, axis=0)
48 | all_residuals = all_residuals / np.sum(residuals_mean)
49 | df_to_plot = pd.DataFrame(all_residuals).melt()
50 | sns.lineplot(x="variable", y="value", data=df_to_plot, ax=ax, estimator=aggregator, ci=68, marker='o',
51 | color='k').set(xlabel='Frequencies (Hz)', ylabel='Frequency Influence (%)')
52 | ax.set_xticks(np.arange(0, len(frequencies), frequency_spacing))
53 | ax.set_xticklabels(np.round(frequencies[0::frequency_spacing], 2), fontsize=8, rotation=45)
54 | ax.set_title(on)
55 | for ax in axes:
56 | ax.invert_xaxis()
57 | sns.despine()
58 | fig.tight_layout()
59 | fig.show()
60 |
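# Usage sketch (hypothetical file): assumes the losses and shuffled losses have
# already been written to the analysis/ group by deepinsight.analyse, as in the
# example notebooks.
#
#   plot_residuals('processed.h5', output_names=['Position'],
#                  frequency_spacing=2, offset=0.1)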
--------------------------------------------------------------------------------
/media/colab_walkthrough.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/colab_walkthrough.gif
--------------------------------------------------------------------------------
/media/decoding_error.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/decoding_error.gif
--------------------------------------------------------------------------------
/media/model_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/model_architecture.png
--------------------------------------------------------------------------------
/notebooks/deepinsight_calcium_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "pn9vUyWfZTX6",
8 | "slideshow": {
9 | "slide_type": "slide"
10 | }
11 | },
12 | "source": [
13 | "---\n",
14 | "# **Introduction to DeepInsight - Decoding position from two-photon calcium recordings**\n",
15 | "---\n",
16 | "\n",
17 | "This notebook stands as an example of how to use DeepInsight on calcium data and can be used as a guide on how to adapt it to your own datasets. All methods are stored in the deepinsight library and can be called directly or in their respective submodules. A typical workflow might look like the following: \n",
18 | "\n",
19 | "- Load your dataset into a format which can be directly indexed (numpy array or pointer to a file on disk)\n",
20 | "- Preprocess the raw data (wavelet transformation)\n",
21 | "- Preprocess your outputs (the variable you want to decode)\n",
22 | "- Define appropriate loss functions for your output and train the model \n",
23 | "- Predict performance across all cross validated models\n",
24 | "- Visualize influence of different input frequencies on model output\n",
25 | "\n",
26 | "We use the calcium dataset here as it has lower sampling rate and is therefore faster to preprocess and train, which makes it suitable to also run the preprocessing in a Colab notebook.\n"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "colab_type": "text",
33 | "id": "9iwZvplEoO70",
34 | "slideshow": {
35 | "slide_type": "subslide"
36 | }
37 | },
38 | "source": [
39 | "---\n",
40 | "## **Install and import DeepInsight**\n",
41 | "---\n",
42 | "Make sure you are using a **GPU runtime** if you want to train your own models. Go to Runtime -> Change Runtime type to change from CPU to GPU.\n",
43 | "You can check the GPU which is used in Colab by running !nvidia-smi in a new cell "
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {
50 | "colab": {
51 | "base_uri": "https://localhost:8080/",
52 | "height": 51
53 | },
54 | "colab_type": "code",
55 | "id": "Uguw1SjlZLRX",
56 | "outputId": "fa22ce71-f3ff-4ff7-ac55-491d7011f2e0",
57 | "slideshow": {
58 | "slide_type": "subslide"
59 | }
60 | },
61 | "outputs": [],
62 | "source": [
63 | "# Import DeepInsight\n",
64 | "import sys\n",
65 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n",
66 | "import deepinsight\n",
67 | "\n",
68 | "# Other imports\n",
69 | "import os\n",
70 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
71 | "import h5py\n",
72 | "import numpy as np\n",
73 | "import pandas as pd\n",
74 | "from scipy.io import loadmat\n",
75 | "import plotly.graph_objs as go\n",
76 | "from skimage import io\n",
77 | "\n",
78 | "# Initialize plotly figures\n",
79 | "from plotly.offline import init_notebook_mode \n",
80 | "init_notebook_mode(connected = True)\n",
81 | "\n",
82 | "# Make sure the output width is adjusted for better export as HTML\n",
83 | "from IPython.core.display import display, HTML\n",
84 | "display(HTML(\"\"))\n",
85 | "display(HTML(\"\"))"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {
91 | "colab_type": "text",
92 | "id": "INZc18eRZTYL",
93 | "slideshow": {
94 | "slide_type": "slide"
95 | }
96 | },
97 | "source": [
98 | "---\n",
99 | "## **Load and preprocess your data**\n",
100 | "---\n",
101 | "For this example we provide two-photon calcium imaging data from a mouse in a virtual environment. Calcium traces together with the variable of interest is stored in one .mat file. You can load it from whatever datasource you want, just make sure that the dimensions match. \n",
102 | "\n",
103 | "The input to the model in the form of (Timepoints x Number of Cells) is stored in `raw_data`\n",
104 | "\n",
105 | "The output to be decoded is in the form of (Timepoints x 1) and is stored in `output` together with timestamps for the output in `raw_timestamps`\n",
106 | "\n",
107 | "---\n",
108 | "\n",
109 | "**Run the next cells if you want to load the example data and preprocess it. You can also skip to 'Preprocess Data' to just load the preprocessed file and directly train the model.** "
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "colab": {
117 | "base_uri": "https://localhost:8080/",
118 | "height": 289
119 | },
120 | "colab_type": "code",
121 | "id": "Q1GNMR41iqBf",
122 | "outputId": "42008674-03b7-4837-dbe8-372a89a36e03",
123 | "slideshow": {
124 | "slide_type": "subslide"
125 | }
126 | },
127 | "outputs": [],
128 | "source": [
129 | "base_path = './example_data/calcium/'\n",
130 | "fp_raw_file = base_path + 'traces_M1336.mat'\n",
131 | "if not os.path.exists(base_path):\n",
132 | " os.makedirs(base_path)\n",
133 | "if not os.path.exists(fp_raw_file): # Careful as next command is a colab command where parameters had to be hard coded. Keep in mind if changing fp_raw_file\n",
134 | " !wget https://ndownloader.figshare.com/files/24024683 -O ./example_data/calcium/traces_M1336.mat"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {
141 | "colab": {
142 | "base_uri": "https://localhost:8080/",
143 | "height": 34
144 | },
145 | "colab_type": "code",
146 | "id": "uCT9WQ56wTia",
147 | "outputId": "3a530bab-7aee-4675-b062-ba656c7f2c8b",
148 | "slideshow": {
149 | "slide_type": "subslide"
150 | }
151 | },
152 | "outputs": [],
153 | "source": [
154 | "# Set base variables\n",
155 | "sampling_rate = 30 # Might also be stored in above mat file for easier access\n",
156 | "channels = np.arange(0, 100) # For this recording channels corresponds to cells. We only use the first 100 cells to speed up preprocessing (Change this if you run it on your own dataset)\n",
157 | "\n",
158 | "# Also define Paths to access downloaded files\n",
159 | "base_path = './example_data/calcium/'\n",
160 | "fp_raw_file = base_path + 'traces_M1336.mat' # This is an example dataset containing calcium traces and linear position in a virtual track\n",
161 | "fp_deepinsight = base_path + 'processed_M1336.h5' # This will be the processed HDF5 file\n",
162 | "\n",
163 | "# Load data from mat file\n",
164 | "calcium_data = loadmat(fp_raw_file)['dataSave']\n",
165 | "raw_data = np.squeeze(calcium_data['df_f'][0][0])\n",
166 | "raw_timestamps = np.arange(0, raw_data.shape[0]) / sampling_rate\n",
167 | "output = np.squeeze(calcium_data['pos_dat'][0][0])\n",
168 | "\n",
169 | "print('Data loaded. Calcium traces: {}, Decoding target {}'.format(raw_data.shape, output.shape))"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "colab_type": "text",
176 | "id": "fNbPYTuf79de",
177 | "slideshow": {
178 | "slide_type": "slide"
179 | }
180 | },
181 | "source": [
182 | "---\n",
183 | "### Plot example calcium traces\n",
184 | "---\n",
185 | "To give a visual impression of the input to our model we can now plot calcium traces for a bunch of different cells. "
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {
192 | "colab": {
193 | "base_uri": "https://localhost:8080/",
194 | "height": 542
195 | },
196 | "colab_type": "code",
197 | "id": "s7LG4uj8ywFw",
198 | "outputId": "05677e37-cfd3-4ea9-ff8a-f6c1107f5330",
199 | "slideshow": {
200 | "slide_type": "subslide"
201 | }
202 | },
203 | "outputs": [],
204 | "source": [
205 | "end_point, y_offset, num_cells = 10000, 400, 6\n",
206 | "fig = go.Figure()\n",
207 | "for i in range(0, num_cells):\n",
208 | " fig.add_trace(go.Scatter(x=np.arange(0, end_point) / sampling_rate, y=raw_data[0:end_point, i] + (i * y_offset), line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Cell {}'.format(i+1)))\n",
209 | "# aesthetics\n",
210 | "fig.update_yaxes(visible=False)\n",
211 | "fig.update_layout(showlegend=False,plot_bgcolor=\"white\",width=1800, height=650,margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))\n",
212 | "fig.show()"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "colab_type": "text",
219 | "id": "9ubJSXsC_yh9",
220 | "slideshow": {
221 | "slide_type": "slide"
222 | }
223 | },
224 | "source": [
225 | "---\n",
226 | "### Preprocess data \n",
227 | "---"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {
234 | "colab": {
235 | "base_uri": "https://localhost:8080/",
236 | "height": 156
237 | },
238 | "colab_type": "code",
239 | "id": "O61lbp1TZTYM",
240 | "outputId": "6aad0856-0424-4954-e985-1c3f8f672ff7",
241 | "slideshow": {
242 | "slide_type": "subslide"
243 | }
244 | },
245 | "outputs": [],
246 | "source": [
247 | "if not os.path.exists(fp_deepinsight):\n",
248 | " if os.path.exists(fp_raw_file): # Only do this if user downloaded raw files otherwise download preprocessed hdf5 file\n",
249 | " # Process output for use as decoding target\n",
250 | " # As the mouse is running on a virtual linear track we have a circular variable. We can solve this by either:\n",
251 | " # (1) Using a circular loss function or \n",
252 | " # (2) Using the sin and cos of the variable\n",
253 | " # For this dataset we choose method (2), see the loss calculation for head directionality on CA1 recordings for an example of (1)\n",
254 | " output = (output - np.nanmin(output)) / (np.nanmax(output) - np.nanmin(output))\n",
255 | " output = (output * 2*np.pi) - np.pi # Scaled to -pi / pi\n",
256 | " output = np.squeeze(np.column_stack([np.sin(output), np.cos(output)]))\n",
257 | " output = pd.DataFrame(output).ffill().bfill().values # Get rid of NaNs\n",
258 | " output_timestamps = raw_timestamps # In this recording timestamps are the same for output and raw_data, meaning they are already aligned to each other\n",
259 | "\n",
260 | " # Transform raw data to frequency domain\n",
261 | " # We use a small cutoff (1/500) for the low frequencies to keep the dimensions low & the model training fast\n",
262 | " deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=sampling_rate, average_window=1, wave_highpass=1/500, wave_lowpass=sampling_rate, channels=channels) \n",
263 | " # # Prepare outputs\n",
264 | " deepinsight.preprocess.preprocess_output(fp_deepinsight, raw_timestamps, output, output_timestamps, average_window=1, dataset_name='sin_cos')\n",
265 | " else:\n",
266 | " if not os.path.exists(base_path):\n",
267 | " os.makedirs(base_path)\n",
268 | " if not os.path.exists(fp_deepinsight):\n",
269 | " !wget https://ndownloader.figshare.com/files/23658674 -O ./example_data/calcium/processed_M1336.h5"
270 | ]
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "metadata": {
275 | "colab_type": "text",
276 | "id": "PEjBzYQYo9d9",
277 | "slideshow": {
278 | "slide_type": "slide"
279 | }
280 | },
281 | "source": [
282 | "---\n",
283 | "### Plot preprocessed data\n",
284 | "---\n",
285 | "We plot examples to double check the wavelet preprocessing. Each plot shows the wavelet processed calcium traces for one cell"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": null,
291 | "metadata": {
292 | "colab": {},
293 | "colab_type": "code",
294 | "id": "xeS4h0_5k54q",
295 | "slideshow": {
296 | "slide_type": "skip"
297 | }
298 | },
299 | "outputs": [],
300 | "source": [
301 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n",
302 | "wavelets = hdf5_file['inputs/wavelets']\n",
303 | "frequencies = np.round(hdf5_file['inputs/fourier_frequencies'], 3)"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "metadata": {
310 | "colab": {
311 | "base_uri": "https://localhost:8080/",
312 | "height": 542
313 | },
314 | "colab_type": "code",
315 | "id": "3Isfo5s5CIhO",
316 | "outputId": "122a960f-05fa-4831-f442-c65daf4a92f3",
317 | "slideshow": {
318 | "slide_type": "subslide"
319 | }
320 | },
321 | "outputs": [],
322 | "source": [
323 | "num_cells, gap = 20, 30\n",
324 | "fig = go.Figure()\n",
325 | "for i in range(0, num_cells):\n",
326 | " this_z = wavelets[0:wavelets.shape[0]//2:gap,:,i].transpose()\n",
327 | " fig.add_heatmap(x=np.arange(0, this_z.shape[0]) / (sampling_rate / gap), z=this_z,colorscale='Viridis',visible=False,showscale=False)\n",
328 | "fig.data[0].visible = True\n",
329 | "# aesthetics\n",
330 | "steps = []\n",
331 | "for i in range(len(fig.data)):\n",
332 | " step = dict(method=\"update\",label=\"Cell {}\".format(i+1),args=[{\"visible\": [False] * len(fig.data)}])\n",
333 | " step[\"args\"][0][\"visible\"][i] = True # Toggle i'th trace to \"visible\"\n",
334 | " steps.append(step)\n",
335 | "sliders = [dict(active=10,currentvalue={\"prefix\": \"Cell: \", \"visible\" : False},pad={\"t\": 70},steps=steps)]\n",
336 | "\n",
337 | "fig.update_layout(width=1800, height=650,sliders=sliders, yaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'), yaxis_title='Frequency (Hz)',\n",
338 | " showlegend=False, plot_bgcolor=\"white\",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))\n",
339 | "fig"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "colab": {},
347 | "colab_type": "code",
348 | "id": "yIXVttGdEmzd",
349 | "slideshow": {
350 | "slide_type": "skip"
351 | }
352 | },
353 | "outputs": [],
354 | "source": [
355 | "hdf5_file.close()"
356 | ]
357 | },
358 | {
359 | "cell_type": "markdown",
360 | "metadata": {
361 | "colab_type": "text",
362 | "id": "R2hXQTtuZTYX",
363 | "slideshow": {
364 | "slide_type": "slide"
365 | }
366 | },
367 | "source": [
368 | "---\n",
369 | "## **Train the model**\n",
370 | "---\n",
371 | "The following command uses 5 cross validations to train the models and stores weights in HDF5 files"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {
378 | "colab": {
379 | "base_uri": "https://localhost:8080/",
380 | "height": 1000
381 | },
382 | "colab_type": "code",
383 | "id": "z3wSllkHZTYY",
384 | "outputId": "c6715e26-35b4-40bf-9623-1628a02600ae",
385 | "slideshow": {
386 | "slide_type": "subslide"
387 | }
388 | },
389 | "outputs": [],
390 | "source": [
391 | "# Define loss functions and train model, if more then one behaviour/stimuli needs to be decoded, define loss functions and weights for each of them here\n",
392 | "loss_functions = {'sin_cos' : 'mse'}\n",
393 | "loss_weights = {'sin_cos' : 1} \n",
394 | "user_opts = {'epochs' : 10, 'sample_per_epoch' : 250} # Speed up for Colab, normally set to {'epochs' : 20, 'sample_per_epoch' : 250\n",
395 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts=user_opts)"
396 | ]
397 | },
398 | {
399 | "cell_type": "markdown",
400 | "metadata": {
401 | "colab_type": "text",
402 | "id": "p6ybugPYcuRe",
403 | "slideshow": {
404 | "slide_type": "slide"
405 | }
406 | },
407 | "source": [
408 | "---\n",
409 | "## **Evaluate model performance**\n",
410 | "---\n",
411 | "Here we calculate the losses over the whole duration of the experiment. Step size indicates how many timesteps are skipped between samples. Note that each sample contains 64 timesteps, so setting step size to 64 will result in non-overlapping samples"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {
418 | "colab": {
419 | "base_uri": "https://localhost:8080/",
420 | "height": 156
421 | },
422 | "colab_type": "code",
423 | "id": "ZBrHYchnVVck",
424 | "outputId": "d310f7e3-e760-4382-87cc-4965aea6dbcc",
425 | "slideshow": {
426 | "slide_type": "subslide"
427 | }
428 | },
429 | "outputs": [],
430 | "source": [
431 | "step_size = 100\n",
432 | "\n",
433 | "# Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file\n",
434 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=step_size)\n",
435 | "\n",
436 | "# Get real output from HDF5 file\n",
437 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n",
438 | "output_real = hdf5_file['outputs/sin_cos'][indices,:]"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {
444 | "colab_type": "text",
445 | "id": "2WUiAPxPdw6c",
446 | "slideshow": {
447 | "slide_type": "slide"
448 | }
449 | },
450 | "source": [
451 | "---\n",
452 | "### Visualize model performance\n",
453 | "---\n",
454 | "We plot the real output vs. the predicted output for the above trained models. The real output is linearized as in the virtual reality environment the start follows after the mouse reaches the end, therefore we can use a circular variable. Also note that the example plot below is only trained on a subset of channels (see channels variable, default=100) and a limited number of epochs (see epochs, default=5), to make training in the Colab notebook faster. The performance on the fully evaluated dataset is higher. "
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": null,
460 | "metadata": {
461 | "colab": {
462 | "base_uri": "https://localhost:8080/",
463 | "height": 542
464 | },
465 | "colab_type": "code",
466 | "id": "TrimxQpIX20O",
467 | "outputId": "c831ca7a-aa76-47b0-ed45-80fb788b8c40",
468 | "slideshow": {
469 | "slide_type": "subslide"
470 | }
471 | },
472 | "outputs": [],
473 | "source": [
474 | "fig = go.Figure()\n",
475 | "\n",
476 | "fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_real[:,0], line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Real'))\n",
477 | "fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_predictions['sin_cos'][:,0], line=dict(color='rgb(67, 116, 144)', width=3), name='Predicted'))\n",
478 | "\n",
479 | "# aesthetics\n",
480 | "#fig.update_yaxes(visible=False)\n",
481 | "fig.update_layout(width=1800, height=650, plot_bgcolor=\"rgb(245, 245, 245)\",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', yaxis_title='Decoding target (sin)', font=dict(family='Open Sans', size=16, color='black'))\n",
482 | "fig"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": null,
488 | "metadata": {
489 | "colab": {},
490 | "colab_type": "code",
491 | "id": "jPkRVpHPX25i",
492 | "slideshow": {
493 | "slide_type": "skip"
494 | }
495 | },
496 | "outputs": [],
497 | "source": [
498 | "hdf5_file.close()"
499 | ]
500 | },
501 | {
502 | "cell_type": "markdown",
503 | "metadata": {
504 | "colab_type": "text",
505 | "id": "inqoPEb4eCmu",
506 | "slideshow": {
507 | "slide_type": "skip"
508 | }
509 | },
510 | "source": [
511 | "---\n",
512 | "### Get shuffled model performance\n",
513 | "---\n",
514 | "We use the shuffled loss to evaluate feature importance"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": null,
520 | "metadata": {
521 | "colab": {
522 | "base_uri": "https://localhost:8080/",
523 | "height": 122
524 | },
525 | "colab_type": "code",
526 | "id": "LM9SKb4FZTYc",
527 | "outputId": "066e9263-c598-428f-f811-3e9169ce9c6e",
528 | "scrolled": true,
529 | "slideshow": {
530 | "slide_type": "skip"
531 | }
532 | },
533 | "outputs": [],
534 | "source": [
535 | "shuffled_losses_ax1 = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1, stepsize=step_size)"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": null,
541 | "metadata": {
542 | "colab": {},
543 | "colab_type": "code",
544 | "id": "dCVfQGTxpm05",
545 | "slideshow": {
546 | "slide_type": "skip"
547 | }
548 | },
549 | "outputs": [],
550 | "source": [
551 | "# Calculate residuals, make sure there is no division by zero by adding small constant.\n",
552 | "residuals = (shuffled_losses_ax1 - losses) / (losses + 0.1)\n",
553 | "residuals_mean = np.mean(residuals, axis=1)[:,0]\n",
554 | "residuals_standarderror = np.std(residuals, axis=1)[:,0] / np.sqrt(residuals.shape[0])"
555 | ]
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "metadata": {
560 | "colab_type": "text",
561 | "id": "1bSLxr8RJvnN",
562 | "slideshow": {
563 | "slide_type": "slide"
564 | }
565 | },
566 | "source": [
567 | "---\n",
568 | "### Show feature importance for frequency axis\n",
569 | "---\n",
570 | "This plot shows the relative influence of each frequency band on the decoding of the position in the virtual environment. We plot the mean across samples + the standard error for each frequency band. "
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "metadata": {
577 | "colab": {
578 | "base_uri": "https://localhost:8080/",
579 | "height": 542
580 | },
581 | "colab_type": "code",
582 | "id": "BCeTJQykprE7",
583 | "outputId": "cd340d3f-82d8-49b1-d77d-ec2bc2b93683",
584 | "slideshow": {
585 | "slide_type": "subslide"
586 | }
587 | },
588 | "outputs": [],
589 | "source": [
590 | "end_point, y_offset, num_cells = 1000, 400, 6\n",
591 | "fig = go.Figure()\n",
592 | "\n",
593 | "fig.add_trace(go.Scatter(x=np.arange(0, residuals_mean.shape[0]), y=residuals_mean, line=dict(color='rgba(0, 0, 0, 0.85)', width=3), name='Real',\n",
594 | " error_y=dict(type='data', array=residuals_standarderror, visible=True, color='rgb(67, 116, 144)', thickness=3)))\n",
595 | "\n",
596 | "# aesthetics\n",
597 | "#fig.update_yaxes(visible=False)\n",
598 | "fig.update_layout(width=1800, height=650, plot_bgcolor=\"rgb(245, 245, 245)\",margin=dict(t=20,l=20,b=20,r=20), xaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'),\n",
599 | " xaxis_title='Frequency (Hz)', yaxis_title='Relative influence', font=dict(family='Open Sans', size=16, color='black',\n",
600 | "))\n",
601 | "fig"
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "metadata": {
607 | "colab_type": "text",
608 | "id": "xW1fimfsJ2Hl",
609 | "slideshow": {
610 | "slide_type": "slide"
611 | }
612 | },
613 | "source": [
614 | "---\n",
615 | "### Show feature importance for cell axis\n",
616 | "---\n",
617 | "For this we shuffle across the cell dimension to see the influence each cell has on the decoding of position and then plot it back to the calcium ROIs. In the plot below the size of the dot is indicating the relative influence of this ROI (cell) on the decoding performance. Red dots indicate a high influence of this cell on the decoding of position and blue dots indicate a negative influence of this cell.\n"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "execution_count": null,
623 | "metadata": {
624 | "colab": {
625 | "base_uri": "https://localhost:8080/",
626 | "height": 122
627 | },
628 | "colab_type": "code",
629 | "id": "34Cnoa_fpq9z",
630 | "outputId": "3989f4dc-20d7-4261-fcc1-fee2084058f2",
631 | "slideshow": {
632 | "slide_type": "skip"
633 | }
634 | },
635 | "outputs": [],
636 | "source": [
637 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=2, stepsize=step_size)"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": null,
643 | "metadata": {
644 | "colab": {},
645 | "colab_type": "code",
646 | "id": "bh0ZvGpEKE4r",
647 | "slideshow": {
648 | "slide_type": "skip"
649 | }
650 | },
651 | "outputs": [],
652 | "source": [
653 | "# Calculate residuals, make sure there is no division by zero by adding small constant.\n",
654 | "residuals = (shuffled_losses - losses) / (losses + 0.1)\n",
655 | "residuals_mean = np.mean(residuals, axis=1)[:,0]"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": null,
661 | "metadata": {
662 | "colab": {
663 | "base_uri": "https://localhost:8080/",
664 | "height": 785
665 | },
666 | "colab_type": "code",
667 | "id": "AYtjoLGaF9Ml",
668 | "outputId": "1dab7de7-43dd-46f3-f6ae-b06600045e42",
669 | "slideshow": {
670 | "slide_type": "skip"
671 | }
672 | },
673 | "outputs": [],
674 | "source": [
675 | "# Get some files for plotting the importance of each cell back to brain anatomy\n",
676 | "if not os.path.exists('./example_data/calcium/centroid_YX.mat'):\n",
677 | " !wget https://www.dropbox.com/s/z8ynet2nkt9pe1u/centroid_YX.mat -O ./example_data/calcium/centroid_YX.mat\n",
678 | "if not os.path.exists('./example_data/calcium/calcium_rois.jpg'): \n",
679 | " !wget https://www.dropbox.com/s/czak7rphajslcr0/test_rois_F5.jpg -O ./example_data/calcium/calcium_rois.jpg\n",
680 | "roi_data = loadmat('./example_data/calcium/centroid_YX.mat')['xy_coords']"
681 | ]
682 | },
683 | {
684 | "cell_type": "code",
685 | "execution_count": null,
686 | "metadata": {
687 | "colab": {
688 | "base_uri": "https://localhost:8080/",
689 | "height": 542
690 | },
691 | "colab_type": "code",
692 | "id": "nMYWGuPb6AWK",
693 | "outputId": "a1e353e7-ad0b-4a9e-ffda-23bb74a952fb",
694 | "slideshow": {
695 | "slide_type": "slide"
696 | }
697 | },
698 | "outputs": [],
699 | "source": [
700 | "fig = go.Figure()\n",
701 | "point_size_adjustment = 1250\n",
702 | "all_pos = residuals_mean > 0\n",
703 | "all_pos_channels = channels[all_pos]\n",
704 | "all_neg_channels = channels[~all_pos]\n",
705 | "fig.add_trace(go.Image(z=io.imread('./example_data/calcium/calcium_rois.jpg')))\n",
706 | "fig.add_trace(go.Scatter(x=roi_data[:,1], y=roi_data[:,0], marker_symbol='circle', mode='markers', marker=dict(color='white', opacity=0.5, line=dict(color='white',width=0)), name='Cell centers'))\n",
707 | "fig.add_trace(go.Scatter(x=roi_data[all_pos_channels,1], y=roi_data[all_pos_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='red', size=residuals_mean[all_pos]*point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Pos. influence'))\n",
708 | "fig.add_trace(go.Scatter(x=roi_data[all_neg_channels,1], y=roi_data[all_neg_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='blue', size=residuals_mean[~all_pos]*-point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Neg. influence'))\n",
709 | "\n",
710 | "fig.update_layout(width=1800, height=650, showlegend=False, plot_bgcolor=\"white\",margin=dict(t=10,l=0,b=10,r=0), xaxis=dict(showticklabels=False), yaxis=dict(showticklabels=False))\n",
711 | "fig"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": null,
717 | "metadata": {
718 | "colab": {},
719 | "colab_type": "code",
720 | "id": "xR73Csfr-U83",
721 | "slideshow": {
722 | "slide_type": "skip"
723 | }
724 | },
725 | "outputs": [],
726 | "source": []
727 | }
728 | ],
729 | "metadata": {
730 | "accelerator": "GPU",
731 | "celltoolbar": "Slideshow",
732 | "colab": {
733 | "collapsed_sections": [],
734 | "name": "deepinsight_calcium_example.ipynb",
735 | "provenance": [],
736 | "toc_visible": true
737 | },
738 | "kernelspec": {
739 | "display_name": "Python 3.7.6 64-bit",
740 | "language": "python",
741 | "name": "python37664bit5fa017aec819437bacf63081b14c694c"
742 | },
743 | "language_info": {
744 | "codemirror_mode": {
745 | "name": "ipython",
746 | "version": 3
747 | },
748 | "file_extension": ".py",
749 | "mimetype": "text/x-python",
750 | "name": "python",
751 | "nbconvert_exporter": "python",
752 | "pygments_lexer": "ipython3",
753 | "version": "3.7.10"
754 | }
755 | },
756 | "nbformat": 4,
757 | "nbformat_minor": 1
758 | }
759 |
--------------------------------------------------------------------------------
/notebooks/example_data/calcium/calcium_rois.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/notebooks/example_data/calcium/calcium_rois.jpg
--------------------------------------------------------------------------------
/notebooks/static/ephys_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "---\n",
8 | "---\n",
9 | "\n",
10 | "# Introduction to DeepInsight - Decoding position, speed and head direction from tetrode CA1 recordings\n",
11 | "\n",
12 | "This notebook stands as an example of how to use DeepInsight v0.5 on tetrode data and can be used as a guide on how to adapt it to your own datasets. All methods are stored in the deepinsight library and can be called directly or in their respective submodules. A typical workflow might look like the following: \n",
13 | "- Load your dataset into a format which can be directly indexed (numpy array or pointer to a file on disk)\n",
14 | "- Preprocess the raw data (wavelet transformation)\n",
15 | "- Preprocess your outputs (the variable you want to decode)\n",
16 | "- Define appropriate loss functions for your output and train the model \n",
17 | "- Predict performance across all cross validated models\n",
18 | "- Visualize influence of different input frequencies on model output\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "# Import DeepInsight\n",
28 | "import sys\n",
29 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n",
30 | "import deepinsight\n",
31 | "# Choose GPU\n",
32 | "import os\n",
33 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "---\n",
41 | "---\n",
42 | "Here you can define the paths to your raw data files, and create file names for the preprocessed HDF5 datasets.\n",
43 | "\n",
44 | "The data we use here is usually relatively large in its raw format. Running it through the next lines takes roughly 24 hours for a 40 minute recording.\n",
45 | "\n",
46 | "We provide a preprocess file to play with the code. See next cell"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "# Define base paths\n",
56 | "base_path = './example_data/'\n",
57 | "fp_raw_file = base_path + 'experiment_1.nwb' # This is your raw file\n",
58 | "fp_deepinsight = base_path + 'processed_R2478.h5' # This will be the processed HDF5 file\n",
59 | "\n",
60 | "if os.path.exists(fp_raw_file):\n",
61 | " # Load data \n",
62 | " (raw_data,\n",
63 | " raw_timestamps,\n",
64 | " output,\n",
65 | " output_timestamps,\n",
66 | " info) = deepinsight.util.tetrode.read_tetrode_data(fp_raw_file)\n",
67 | " # Transform raw data to frequency domain\n",
68 | " deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=info['sampling_rate'],\n",
69 | " channels=info['channels'])\n",
70 | " # Prepare outputs\n",
71 | " deepinsight.util.tetrode.preprocess_output(fp_deepinsight, raw_timestamps, output,\n",
72 | " output_timestamps, sampling_rate=info['sampling_rate'])"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "---\n",
80 | "---\n",
81 | "The above steps create a HDF5 file with all important data for training the model.\n",
82 | "\n",
83 | "You can download the preprocessed dataset by running the following command"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "!wget https://ndownloader.figshare.com/files/20150468 -O ./example_data/processed_R2478.h5"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "---\n",
100 | "---\n",
101 | "Now we can train the model. \n",
102 | "\n",
103 | "The following command uses 5 cross validations to train the models and stores weights in HDF5 files"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# Define loss functions and train model\n",
113 | "loss_functions = {'position' : 'euclidean_loss', \n",
114 | " 'head_direction' : 'cyclical_mae_rad', \n",
115 | " 'speed' : 'mae'}\n",
116 | "loss_weights = {'position' : 1, \n",
117 | " 'head_direction' : 25, \n",
118 | " 'speed' : 2}\n",
119 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {
126 | "scrolled": true
127 | },
128 | "outputs": [],
129 | "source": [
130 | "# Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file\n",
131 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight,\n",
132 | " stepsize=10)\n",
133 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1,\n",
134 | " stepsize=10)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "---\n",
142 | "---\n",
143 | "Above line calculates the loss and shuffled loss across the full experiment and writes it back to the HDF5 file.\n",
144 | "\n",
145 | "Below command visualizes the influence across different frequency bands for all samples\n",
146 | "\n",
147 | "Note that Figure 3 in the manuscript shows influence across animals, while this plot shows the influence for one animal across the experiment"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "# Plot influence across behaviours\n",
157 | "deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2,\n",
158 | " output_names=['Position', 'Head Direction', 'Speed'])"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "---\n",
166 | "---"
167 | ]
168 | }
169 | ],
170 | "metadata": {
171 | "kernelspec": {
172 | "display_name": "Python 3",
173 | "language": "python",
174 | "name": "python3"
175 | },
176 | "language_info": {
177 | "codemirror_mode": {
178 | "name": "ipython",
179 | "version": 3
180 | },
181 | "file_extension": ".py",
182 | "mimetype": "text/x-python",
183 | "name": "python",
184 | "nbconvert_exporter": "python",
185 | "pygments_lexer": "ipython3",
186 | "version": "3.7.10"
187 | }
188 | },
189 | "nbformat": 4,
190 | "nbformat_minor": 2
191 | }
192 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu
2 | numpy
3 | pandas
4 | joblib
5 | seaborn
6 | matplotlib
7 | h5py
8 | scipy
9 | ipython
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from setuptools import setup, find_packages
8 |
9 | long_description = open('README.md').read()
10 | with open('requirements.txt') as f:
11 | requirements = f.read().splitlines()
12 |
13 | setup(
14 | name='deepinsight',
15 | version='0.5',
16 | install_requires=requirements,
17 | author='Markus Frey',
18 | author_email='markus.frey1@gmail.com',
19 | description="A general framework for interpreting wide-band neural activity",
20 | long_description=long_description,
21 | url='https://github.com/CYHSM/DeepInsight/',
22 | license='MIT',
23 | classifiers=[
24 | 'Development Status :: 3 - Alpha',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Intended Audience :: Developers',
27 | 'Natural Language :: English',
28 | 'Operating System :: OS Independent',
29 | 'Programming Language :: Python',
30 | 'Programming Language :: Python :: 3',
31 | ],
32 | packages=find_packages(),
33 | package_data={
34 | "": ["*.p", "*.h5", "*.csv", "*.gif", "*.png", "*.txt"],
35 | },
36 | )
37 |
--------------------------------------------------------------------------------
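One packaging detail: long_description is read from a Markdown README, but setup() never declares its content type, so PyPI would render the description as plain text. A hedged sketch of the relevant addition (long_description_content_type is a standard setuptools keyword, though it is not in the original file):

    from setuptools import setup

    setup(
        name='deepinsight',
        version='0.5',
        long_description=open('README.md').read(),
        # Tell PyPI to render the README as Markdown instead of plain text
        long_description_content_type='text/markdown',
    )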
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/run_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import h5py
4 | import deepinsight
5 |
6 | import numpy as np
7 | import unittest
8 | unittest.TestLoader.sortTestMethodsUsing = None
9 |
10 |
11 | class TestDeepInsight(unittest.TestCase):
12 | """Simple Testing Class"""
13 |
14 | def tearDown(self):
15 | time.sleep(0.1)
16 |
17 | def setUp(self):
18 | unittest.TestCase.setUp(self)
19 | np.random.seed(0)
20 | self.fp_deepinsight_folder = os.getcwd() + '/tests/test_files/'
21 | self.fp_deepinsight = self.fp_deepinsight_folder + 'test.h5'
22 |         # Ensure the test folder exists and start from a clean HDF5 file
23 |         os.makedirs(self.fp_deepinsight_folder, exist_ok=True)
24 |         if os.path.exists(self.fp_deepinsight):
25 |             os.remove(self.fp_deepinsight)
26 | self.input_length = int(3e5)
27 | self.input_channels = 5
28 | self.sampling_rate = 30000
29 | self.input_output_ratio = 100
30 | self.average_window = 10
31 |
32 | self.rand_input = np.sin(np.random.rand(
33 | int(self.input_length), self.input_channels))
34 | self.rand_input_timesteps = np.arange(0, self.input_length)
35 | self.rand_output = np.random.rand(
36 | self.input_length // self.input_output_ratio)
37 | self.rand_timesteps = np.arange(
38 | 0, self.input_length, self.input_output_ratio)
39 |
40 | def test_fullrun(self):
41 | """
42 |         Runs the full pipeline (wavelet preprocessing, training, loss analysis) on a random signal
43 | """
44 | # Transform raw data to frequency domain
45 | deepinsight.preprocess.preprocess_input(
46 | self.fp_deepinsight, self.rand_input, sampling_rate=self.sampling_rate, average_window=self.average_window)
47 | hdf5_file = h5py.File(self.fp_deepinsight, mode='r')
48 | # Get wavelets from hdf5 file
49 | input_wavelets = hdf5_file['inputs/wavelets']
50 | # Check statistics of wavelets
51 | np.testing.assert_almost_equal(np.mean(input_wavelets), 0.048329710)
52 | np.testing.assert_almost_equal(np.std(input_wavelets), 0.04667989)
53 | np.testing.assert_almost_equal(np.median(input_wavelets), 0.03440293)
54 | np.testing.assert_almost_equal(np.max(input_wavelets), 0.60365933)
55 | np.testing.assert_almost_equal(np.min(input_wavelets), 3.78198024e-08)
56 | hdf5_file.close()
57 |
58 | # Prepare outputs
59 | deepinsight.preprocess.preprocess_output(
60 | self.fp_deepinsight, self.rand_input_timesteps, self.rand_output, self.rand_timesteps, average_window=self.average_window)
61 |
62 | # Define loss functions and train model
63 | loss_functions = {'aligned': 'mse'}
64 | loss_weights = {'aligned': 1}
65 | user_opts = {'epochs': 2, 'steps_per_epoch': 10,
66 | 'validation_steps': 10, 'log_output': False, 'save_model': True}
67 |
68 | deepinsight.train.run_from_path(
69 | self.fp_deepinsight, loss_functions, loss_weights, user_opts)
70 |
71 |         # Get loss and shuffled loss for the influence plot; both are also stored back to the HDF5 file
72 | losses, output_predictions, indices, output_real = deepinsight.analyse.get_model_loss(
73 | self.fp_deepinsight, stepsize=10)
74 |
75 | shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(
76 | self.fp_deepinsight, axis=1, stepsize=10)
77 |
78 |
79 | if __name__ == '__main__':
80 | unittest.main(warnings='ignore')
81 |
--------------------------------------------------------------------------------
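The final two calls mirror the influence analysis in the notebooks: get_shuffled_model_loss(axis=1) re-evaluates the model after shuffling the wavelet input along the frequency axis, so any rise over the unshuffled loss indicates how much the model relies on that part of the input. A minimal sketch of the underlying permutation-importance idea (model_loss_fn is a hypothetical evaluator; the actual implementation lives in deepinsight/analyse.py):

    import numpy as np

    def per_band_shuffled_losses(model_loss_fn, wavelets, axis=1, seed=0):
        # For each index along `axis` (e.g. each frequency band), shuffle
        # that slice across time and re-evaluate the loss. Bands whose
        # shuffling degrades the loss most carry the most information.
        rng = np.random.default_rng(seed)
        losses = []
        for band in range(wavelets.shape[axis]):
            shuffled = wavelets.copy()
            sl = [slice(None)] * wavelets.ndim
            sl[axis] = band
            shuffled[tuple(sl)] = shuffled[tuple(sl)][rng.permutation(wavelets.shape[0])]
            losses.append(model_loss_fn(shuffled))
        return np.array(losses)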
/tests/tests.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "seed_value = 0\n",
10 | "\n",
11 | "# Import DeepInsight\n",
12 | "import sys\n",
13 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n",
14 | "import deepinsight\n",
15 | "# Choose GPU\n",
16 | "import os\n",
17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
18 | "os.environ['PYTHONHASHSEED']=str(0)\n",
19 | "import tensorflow as tf\n",
20 | "tf.random.set_seed(seed_value)\n",
21 | "# Also numpy random generator\n",
22 | "import numpy as np\n",
23 | "np.random.seed(seed_value)\n",
24 | "\n",
25 | "import numpy as np\n",
26 | "import h5py\n",
27 | "%load_ext autoreload\n",
28 | "%autoreload 2"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "#%run run_test.py"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": []
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "fp_deepinsight = './test_files/test.h5'\n",
54 | "if os.path.exists(fp_deepinsight):\n",
55 | " os.remove(fp_deepinsight)\n",
56 | "input_length = int(3e5)\n",
57 | "input_channels = 5\n",
58 | "sampling_rate = 30000\n",
59 | "input_output_ratio = 100\n",
60 | "\n",
61 | "np.random.seed(0)\n",
62 | "rand_input = np.sin(np.random.rand(int(input_length), input_channels))\n",
63 | "rand_input_timesteps = np.arange(0, input_length)\n",
64 | "rand_output = np.random.rand(input_length // input_output_ratio)\n",
65 | "rand_timesteps = np.arange(0, input_length, input_output_ratio)\n",
66 | "\n",
67 | "print(rand_input[0,0])\n",
68 | "print(rand_input_timesteps[0:10])\n",
69 | "print(rand_output[0])\n",
70 | "print(rand_timesteps[0:10])"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# Transform raw data to frequency domain\n",
80 | "deepinsight.preprocess.preprocess_input(fp_deepinsight, rand_input, sampling_rate=sampling_rate, average_window=10)\n",
81 | "\n",
82 | "# Test cases\n",
83 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n",
84 | "# Get size of wavelets\n",
85 | "input_wavelets = hdf5_file['inputs/wavelets']\n",
86 | "# Check statistics of wavelets\n",
87 | "np.testing.assert_almost_equal(np.mean(input_wavelets), 0.048329726)\n",
88 | "np.testing.assert_almost_equal(np.std(input_wavelets), 0.032383125)\n",
89 | "np.testing.assert_almost_equal(np.median(input_wavelets), 0.04608967)\n",
90 | "np.testing.assert_almost_equal(np.max(input_wavelets), 0.40853173)\n",
91 | "np.testing.assert_almost_equal(np.min(input_wavelets), 1.6544704e-05)\n",
92 | "hdf5_file.close()\n"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "print('Mean {:.10}, Std {:.10}, Median {:.10}, Max {:.10}, Min {:.10}'.format(np.mean(input_wavelets), np.std(input_wavelets), np.median(input_wavelets), np.max(input_wavelets), np.min(input_wavelets)))"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": []
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "# Prepare outputs\n",
118 | "deepinsight.preprocess.preprocess_output(fp_deepinsight, rand_input_timesteps, rand_output,\n",
119 | " rand_timesteps)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "# Define loss functions and train model\n",
129 | "loss_functions = {'output_aligned' : 'mse'}\n",
130 | "loss_weights = {'output_aligned' : 1}\n",
131 | "user_opts = {'epochs' : 2, 'steps_per_epoch' : 10, 'validation_steps' : 10, 'log_output' : False, 'save_model' : False}\n",
132 | "\n",
133 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "# Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file\n",
143 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=10)\n",
144 | "\n",
145 | "# Test cases\n",
146 | "np.testing.assert_almost_equal(losses[-1], 1.0168755e-05)\n",
147 | "np.testing.assert_almost_equal(losses[0], 0.53577816)\n",
148 | "np.testing.assert_almost_equal(np.mean(losses), 0.09069238)\n",
149 | "np.testing.assert_almost_equal(np.std(losses), 0.13594063)\n",
150 | "np.testing.assert_almost_equal(np.median(losses), 0.045781307)\n",
151 | "np.testing.assert_almost_equal(np.max(losses), 0.53577816)\n",
152 | "np.testing.assert_almost_equal(np.min(losses), 1.0168755e-05)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": []
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1,stepsize=10)\n",
169 | "\n",
170 | "# Test cases\n",
171 | "np.testing.assert_almost_equal(np.mean(shuffled_losses), 0.09304095)\n",
172 | "np.testing.assert_almost_equal(np.std(shuffled_losses), 0.13982493)\n",
173 | "np.testing.assert_almost_equal(np.median(shuffled_losses), 0.04165206)\n",
174 | "np.testing.assert_almost_equal(np.max(shuffled_losses), 0.7405345)\n",
175 | "np.testing.assert_almost_equal(np.min(shuffled_losses), 2.0834877e-07)"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2,\n",
185 | " output_names=['output_aligned'])"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": []
194 | }
195 | ],
196 | "metadata": {
197 | "kernelspec": {
198 | "display_name": "Python 3.7.6 64-bit",
199 | "language": "python",
200 | "name": "python37664bit5fa017aec819437bacf63081b14c694c"
201 | },
202 | "language_info": {
203 | "codemirror_mode": {
204 | "name": "ipython",
205 | "version": 3
206 | },
207 | "file_extension": ".py",
208 | "mimetype": "text/x-python",
209 | "name": "python",
210 | "nbconvert_exporter": "python",
211 | "pygments_lexer": "ipython3",
212 | "version": "3.7.10"
213 | }
214 | },
215 | "nbformat": 4,
216 | "nbformat_minor": 2
217 | }
218 |
--------------------------------------------------------------------------------
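A closing caveat on the seeding cell at the top of tests.ipynb: assigning os.environ['PYTHONHASHSEED'] inside a running interpreter only affects subprocesses; the current process's hash randomization is fixed at startup. For fully deterministic runs the variable has to be set before Python launches. A sketch of a common workaround for plain scripts (it re-executes the interpreter, so it is not suitable inside a notebook kernel):

    import os
    import sys

    # Re-exec with a fixed hash seed if it was not set at launch; setting the
    # variable inside an already-running process does not change its hashing.
    if os.environ.get('PYTHONHASHSEED') != '0':
        os.environ['PYTHONHASHSEED'] = '0'
        os.execv(sys.executable, [sys.executable] + sys.argv)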