├── requirements.txt
├── .idea
│   ├── encodings.xml
│   ├── vcs.xml
│   ├── .gitignore
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── auditory-eeg-dataset.iml
├── config.json
├── technical_validation
│   ├── experiments
│   │   ├── create_environment.sh
│   │   ├── regression_vlaai.py
│   │   ├── match_mismatch_dilated_convolutional_model.py
│   │   ├── regression_linear_backward_model.py
│   │   └── regression_linear_forward_model.py
│   ├── models
│   │   ├── linear.py
│   │   ├── dilated_convolutional_model.py
│   │   └── vlaai.py
│   └── util
│       ├── split_and_normalize.py
│       ├── plot_results.py
│       └── dataset_generator.py
├── download_code
│   ├── README.md
│   ├── download_script.py
│   └── __init__.py
├── preprocessing_code
│   ├── sparrKULee.yaml
│   └── sparrKULee.py
├── .gitignore
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | brian2
2 | brian2hears
3 | numpy
4 | scipy
5 | mne
6 | librosa
7 | pandas
8 | pybids
9 | seaborn
10 | brain_pipe>=0.0.3
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="Encoding" addBOMForNewFiles="with NO BOM" />
4 | </project>
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | <component name="InspectionProjectProfileManager">
2 |   <settings>
3 |     <option name="USE_PROJECT_PROFILE" value="false" />
4 |     <version value="1.0" />
5 |   </settings>
6 | </component>
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_folder": "null",
3 | "derivatives_folder": "derivatives",
4 | "preprocessed_eeg_folder": "preprocessed_eeg",
5 | "preprocessed_stimuli_folder": "preprocessed_stimuli",
6 | "split_folder": "split_data"
7 | }
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="ProjectModuleManager">
4 |     <modules>
5 |       <module fileurl="file://$PROJECT_DIR$/.idea/auditory-eeg-dataset.iml" filepath="$PROJECT_DIR$/.idea/auditory-eeg-dataset.iml" />
6 |     </modules>
7 |   </component>
8 | </project>
--------------------------------------------------------------------------------
/technical_validation/experiments/create_environment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create the necessary environment variables
4 | source ~/.bashrc
5 | source ~/.local/bin/load_conda_environment.sh bids_preprocessing
6 |
7 | printenv
8 |
9 | # Fix for unrecognized 'CUDA0' values in CUDA_VISIBLE_DEVICES
10 | export CUDA_VISIBLE_DEVICES=$(echo "$CUDA_VISIBLE_DEVICES" | sed -e 's/CUDA//g' -e 's/://g')
11 | export PYTHONPATH="/esat/spchtemp/scratch/baccou/auditory-eeg-dataset/"
12 | echo "$@"
13 |
14 | # Run the original command, preserving argument quoting
15 | "$@"
16 |
17 |
--------------------------------------------------------------------------------
/.idea/auditory-eeg-dataset.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/download_code/README.md:
--------------------------------------------------------------------------------
1 | Tools/info to download the dataset
2 | ==================================
3 |
4 | Code to make downloading the dataset more convenient.
5 |
6 | You can download the full dataset in `.zip` format [from the KU Leuven RDR site by clicking this link](https://rdr.kuleuven.be/api/access/dataset/:persistentId/?persistentId=doi:10.48804/K3VSND).
7 |
8 | If you want more control over what you download, or don't want to restart from scratch when a download fails, you can use [download_script.py](./download_script.py) to download the dataset.
9 |
10 | The code in [download_script.py](./download_script.py) is meant to download [the auditory EEG dataset from the RDR website](https://rdr.kuleuven.be/dataset.xhtml?persistentId=doi:10.48804/K3VSND), but it may also work for other datasets hosted on Dataverse servers.
11 |
12 | # Usage
13 |
14 | ```bash
15 | python download_script.py --help
16 | ```
17 |
18 | ```text
19 | usage: download_script.py [-h] [--server SERVER] [--dataset-id DATASET_ID] [--overwrite] [--skip-checksum] [--multiprocessing MULTIPROCESSING] [--subset {full,preprocessed,stimuli}] download_directory
20 |
21 | Download the auditory EEG dataset from RDR.
22 |
23 | positional arguments:
24 | download_directory Path to download the dataset to.
25 |
26 | options:
27 | -h, --help show this help message and exit
28 | --server SERVER The server to download the dataset from. Default: "rdr.kuleuven.be"
29 | --dataset-id DATASET_ID
30 | The dataset ID to download. Default: "doi:10.48804/K3VSND"
31 | --overwrite Overwrite existing files.
32 | --skip-checksum Whether to skip checksums.
33 | --multiprocessing MULTIPROCESSING
34 | Number of cores to use for multiprocessing. Default: -1 (all cores), set to 0 or 1 to disable multiprocessing.
35 | --subset {full,preprocessed,stimuli}
36 | Download only a subset of the dataset. "full" downloads the full dataset, "preprocessed" downloads only the preprocessed data and "stimuli" downloads only the stimuli files. Default: full
37 |
38 | ```
39 |
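40 | For example, to download only the preprocessed derivatives (substitute your own target folder):
41 |
42 | ```bash
43 | python download_script.py --subset preprocessed /path/to/local/folder
44 | ```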
--------------------------------------------------------------------------------
/download_code/download_script.py:
--------------------------------------------------------------------------------
1 | """Script to download the auditory EEG dataset from RDR."""
2 | import argparse
3 | import os.path
4 |
5 | from download_code import DataverseDownloader, DataverseParser
6 |
7 | if __name__ == "__main__":
8 | download_dir = os.path.dirname(os.path.abspath(__file__))
9 |
10 | parser = argparse.ArgumentParser(
11 | description="Download the auditory EEG dataset from RDR."
12 | )
13 | parser.add_argument(
14 | "--server",
15 | default="rdr.kuleuven.be",
16 | help='The server to download the dataset from. Default: "rdr.kuleuven.be"',
17 | )
18 | parser.add_argument(
19 | "--dataset-id",
20 | default="doi:10.48804/K3VSND",
21 | help='The dataset ID to download. Default: "doi:10.48804/K3VSND"',
22 | )
23 | parser.add_argument(
24 | "--overwrite", action="store_true", help="Overwrite existing files."
25 | )
26 | parser.add_argument(
27 | "--skip-checksum", action="store_true", help="Whether to skip checksums."
28 | )
29 | parser.add_argument(
30 | "--multiprocessing",
31 | type=int,
32 | default=-1,
33 | help="Number of cores to use for multiprocessing. "
34 | "Default: -1 (all cores), set to 0 or 1 to disable multiprocessing.",
35 | )
36 | parser.add_argument(
37 | "--subset",
38 | choices=["full", "preprocessed", "stimuli"],
39 | default="full",
40 | help='Download only a subset of the dataset. '
41 | '"full" downloads the full dataset, '
42 | '"preprocessed" downloads only the preprocessed data and '
43 | '"stimuli" downloads only the stimuli files. Default: full',
44 | )
45 | parser.add_argument(
46 | "download_directory", type=str, help="Path to download the dataset to."
47 | )
48 |
49 | args = parser.parse_args()
50 |
51 | if args.subset == "full":
52 |
53 | def filter_fn(path, file_id):
54 | return True
55 |
56 | elif args.subset == "preprocessed":
57 |
58 | def filter_fn(path, file_id):
59 | return path.startswith("derivatives/")
60 |
61 | elif args.subset == "stimuli":
62 |
63 | def filter_fn(path, file_id):
64 | return path.startswith("stimuli/")
65 |
66 | else:
67 | raise ValueError(f"Unknown subset {args.subset}")
68 |
69 | dataverse_parser = DataverseParser(args.server)
70 | file_id_mapping = dataverse_parser(args.dataset_id)
71 | downloader = DataverseDownloader(
72 | args.download_directory,
73 | args.server,
74 | overwrite=args.overwrite,
75 | multiprocessing=args.multiprocessing,
76 | check_md5=not args.skip_checksum,
77 | )
78 | print(
79 | f"Starting download of set {args.subset} from {args.server} to "
80 | f"{args.download_directory}... (options: overwrite={args.overwrite}, "
81 | f"multiprocessing={args.multiprocessing})"
82 | )
83 | print(f"This might take a while...")
84 | downloader(file_id_mapping, filter_fn=filter_fn)
85 |
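86 | # Note: `filter_fn` receives the dataset-relative file path and the Dataverse
87 | # file ID, so other subsets can be selected by adapting the functions above.
88 | # A minimal, hypothetical sketch that keeps only a single subject's files:
89 | #
90 | #     def filter_fn(path, file_id):
91 | #         return path.startswith("sub-001/")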
--------------------------------------------------------------------------------
/preprocessing_code/sparrKULee.yaml:
--------------------------------------------------------------------------------
1 | dataloaders:
2 | - name: sparrkulee_eeg_loader
3 | callable: GlobLoader
4 | glob_patterns:
5 | - {{ dataset_folder }}/sub-*/*/eeg/*_task-listeningActive_*.bdf*
6 | key: data_path
7 |
8 | pipelines:
9 | - callable: DefaultPipeline
10 | data_from: sparrkulee_eeg_loader
11 | steps:
12 | - callable: LinkStimulusToBrainResponse
13 | stimulus_data:
14 | callable: DefaultPipeline
15 | steps:
16 | - callable: LoadStimuli
17 | load_fn:
18 | callable: temp_stimulus_load_fn
19 | is_pointer: true
20 | - callable: GammatoneEnvelope
21 | # Comment out the following lines if the mel spectrogram is not needed
22 | - callable: LibrosaMelSpectrogram
23 | power_factor: 0.6
24 | librosa_kwargs:
25 | callable: SparrKULeeSpectrogramKwargs
26 | - callable: ResamplePoly
27 | target_frequency: 64
28 | data_key: envelope_data
29 | sampling_frequency_key: stimulus_sr
30 | - callable: DefaultSave
31 | root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_stimuli_dir }}
32 | to_save:
33 | envelope: envelope_data
34 | # Comment out the following line if the mel spectrogram is not needed
35 | mel: spectrogram_data
36 | - callable: DefaultSave
37 | root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_stimuli_dir }}
38 | overwrite: false
39 | grouper:
40 | callable: BIDSStimulusGrouper
41 | bids_root: {{ dataset_folder }}
42 | mapping:
43 | stim_file: stimulus_path
44 | trigger_file: trigger_path
45 | subfolders:
46 | - stimuli
47 | - eeg
48 | - callable: LoadEEGNumpy
49 | unit_multiplier: 1000000
50 | channels_to_select:
51 | {% for channel in range(64) %}
52 | - {{ channel }}
53 | {% endfor %}
54 |
55 | - callable: SosFiltFilt
56 | filter_:
57 | callable: scipy.signal.butter
58 | N: 1
59 | Wn: 0.5
60 | btype: highpass
61 | fs: 1024
62 | output: sos
63 | emulate_matlab: true
64 | axis: 1
65 | - callable: InterpolateArtifacts
66 | - callable: AlignPeriodicBlockTriggers
67 | brain_trigger_processing_fn:
68 | callable: biosemi_trigger_processing_fn
69 | is_pointer: true
70 | - callable: SplitEpochs
71 | - callable: ArtifactRemovalMWF
72 | - callable: CommonAverageRereference
73 | - callable: ResamplePoly
74 | target_frequency: 64
75 | axis: 1
76 | - callable: DefaultSave
77 | root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_eeg_dir }}
78 | to_save:
79 | eeg: data
80 | overwrite: false
81 | filename_fn:
82 | callable: bids_filename_fn
83 | is_pointer: true
84 | clear_output: true
85 |
86 | config:
87 | parser:
88 | extra_paths:
89 | - {{ __filedir__ }}/sparrKULee.py
90 | logging:
91 | log_path: {{ __filedir__ }}/sparrKULee_{datetime}.log
92 |
93 |
94 |
95 |
96 |
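97 | # The {{ ... }} placeholders in this file are template variables. When the file
98 | # is run through the brain_pipe CLI, they are filled in from command-line flags
99 | # whose values should match config.json, e.g.:
100 | # brain_pipe preprocessing_code/sparrKULee.yaml --dataset_folder /path/to/dataset \
101 | #   --derivatives_folder derivatives --preprocessed_stimuli_dir preprocessed_stimuli \
102 | #   --preprocessed_eeg_dir preprocessed_eeg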
--------------------------------------------------------------------------------
/technical_validation/models/linear.py:
--------------------------------------------------------------------------------
1 | """ This module contains linear backward model"""
2 | import tensorflow as tf
3 |
4 |
5 |
6 |
7 | @tf.function
8 | def pearson_loss_cut(y_true, y_pred, axis=1):
9 | """Pearson loss function.
10 |
11 | Parameters
12 | ----------
13 | y_true: tf.Tensor
14 | True values. Shape is (batch_size, time_steps, n_features)
15 | y_pred: tf.Tensor
16 | Predicted values. Shape is (batch_size, time_steps, n_features)
17 |
18 | Returns
19 | -------
20 | tf.Tensor
21 | Pearson loss.
22 | Shape is (batch_size, 1, n_features)
23 | """
24 | return -pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis)
25 |
26 |
27 | @tf.function
28 | def pearson_metric_cut(y_true, y_pred, axis=1):
29 | """Pearson metric function.
30 |
31 | Parameters
32 | ----------
33 | y_true: tf.Tensor
34 | True values. Shape is (batch_size, time_steps, n_features)
35 | y_pred: tf.Tensor
36 | Predicted values. Shape is (batch_size, time_steps, n_features)
37 |
38 | Returns
39 | -------
40 | tf.Tensor
41 | Pearson metric.
42 | Shape is (batch_size, 1) after averaging across features.
43 | """
44 | return tf.reduce_mean(pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis), axis=-1)
45 |
46 | @tf.function
47 | def pearson_metric_cut_not_av(y_true, y_pred, axis=1):
48 | """Pearson metric function.
49 |
50 | Parameters
51 | ----------
52 | y_true: tf.Tensor
53 | True values. Shape is (batch_size, time_steps, n_features)
54 | y_pred: tf.Tensor
55 | Predicted values. Shape is (batch_size, time_steps, n_features)
56 |
57 | Returns
58 | -------
59 | tf.Tensor
60 | Pearson metric.
61 | Shape is (batch_size, 1, n_features)
62 | """
63 | return pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis)
64 |
65 |
66 | def simple_linear_model(integration_window=32, nb_filters=1, nb_channels=64, metric=pearson_metric_cut):
67 | inp = tf.keras.layers.Input(
68 | (
69 | None,
70 | nb_channels,
71 | )
72 | )
73 | out = tf.keras.layers.Conv1D(nb_filters, integration_window)(inp)
74 | model = tf.keras.models.Model(inputs=[inp], outputs=[out])
75 | model.compile(
76 | tf.keras.optimizers.Adam(),
77 | loss=pearson_loss_cut,
78 | metrics=[metric]
79 | )
80 | return model
81 |
82 |
83 |
84 | def pearson_tf(y_true, y_pred, axis=1):
85 | """Pearson correlation function implemented in tensorflow.
86 |
87 | Parameters
88 | ----------
89 | y_true: tf.Tensor
90 | Ground truth labels. Shape is (batch_size, time_steps, n_features)
91 | y_pred: tf.Tensor
92 | Predicted labels. Shape is (batch_size, time_steps, n_features)
93 | axis: int
94 | Axis along which to compute the pearson correlation. Default is 1.
95 |
96 | Returns
97 | -------
98 | tf.Tensor
99 | Pearson correlation.
100 | Shape is (batch_size, 1, n_features) if axis is 1.
101 | """
102 | # Compute the mean of the true and predicted values
103 | y_true_mean = tf.reduce_mean(y_true, axis=axis, keepdims=True)
104 | y_pred_mean = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
105 |
106 | # Compute the numerator and denominator of the pearson correlation
107 | numerator = tf.reduce_sum(
108 | (y_true - y_true_mean) * (y_pred - y_pred_mean),
109 | axis=axis,
110 | keepdims=True,
111 | )
112 | std_true = tf.reduce_sum(tf.square(y_true - y_true_mean), axis=axis, keepdims=True)
113 | std_pred = tf.reduce_sum(tf.square(y_pred - y_pred_mean), axis=axis, keepdims=True)
114 | denominator = tf.sqrt(std_true * std_pred)
115 |
116 | # Compute the pearson correlation
117 | return tf.math.divide_no_nan(numerator, denominator)
118 |
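119 |
120 | # Example usage (a minimal sketch): build the default linear backward model,
121 | # which decodes the speech envelope from 64-channel EEG with a 32-sample
122 | # integration window, and print its architecture.
123 | if __name__ == "__main__":
124 |     model = simple_linear_model()
125 |     model.summary()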
--------------------------------------------------------------------------------
/technical_validation/models/dilated_convolutional_model.py:
--------------------------------------------------------------------------------
1 | """Default dilation model."""
2 | import tensorflow as tf
3 |
4 |
5 | def dilation_model(
6 | time_window=None,
7 | eeg_input_dimension=64,
8 | env_input_dimension=1,
9 | layers=3,
10 | kernel_size=3,
11 | spatial_filters=8,
12 | dilation_filters=16,
13 | activation="relu",
14 | compile=True,
15 | inputs=tuple(),
16 | ):
17 | """Convolutional dilation model.
18 |
19 | Code was taken and adapted from
20 | https://github.com/exporl/eeg-matching-eusipco2020
21 |
22 | Parameters
23 | ----------
24 | time_window : int or None
25 | Segment length. If None, the model will accept every time window input
26 | length.
27 | eeg_input_dimension : int
28 | number of channels of the EEG
29 | env_input_dimension : int
30 | dimemsion of the stimulus representation.
31 | if stimulus == envelope, env_input_dimension =1
32 | if stimulus == mel, env_input_dimension =28
33 | layers : int
34 | Depth of the network/Number of layers
35 | kernel_size : int
36 | Size of the kernel for the dilation convolutions
37 | spatial_filters : int
38 | Number of parallel filters to use in the spatial layer
39 | dilation_filters : int
40 | Number of parallel filters to use in the dilation layers
41 | activation : str or list or tuple
42 | Name of the non-linearity to apply after the dilation layers
43 | or list/tuple of different non-linearities
44 | compile : bool
45 | If model should be compiled
46 | inputs : tuple
47 | Alternative inputs
48 |
49 | Returns
50 | -------
51 | tf.Model
52 | The dilation model
53 |
54 |
55 | References
56 | ----------
57 | Accou, B., Jalilpour Monesi, M., Montoya, J., Van hamme, H. & Francart, T.
58 | Modeling the relationship between acoustic stimulus and EEG with a dilated
59 | convolutional neural network. In 2020 28th European Signal Processing
60 | Conference (EUSIPCO), 1175–1179, DOI: 10.23919/Eusipco47968.2020.9287417
61 | (2021). ISSN: 2076-1465.
62 |
63 | Accou, B., Monesi, M. J., hamme, H. V. & Francart, T.
64 | Predicting speech intelligibility from EEG in a non-linear classification
65 | paradigm. J. Neural Eng. 18, 066008, DOI: 10.1088/1741-2552/ac33e9 (2021).
66 | Publisher: IOP Publishing
67 | """
68 | # If different inputs are required
69 | if len(inputs) == 3:
70 | eeg, env1, env2 = inputs[0], inputs[1], inputs[2]
71 | else:
72 | eeg = tf.keras.layers.Input(shape=[time_window, eeg_input_dimension])
73 | env1 = tf.keras.layers.Input(shape=[time_window, env_input_dimension])
74 | env2 = tf.keras.layers.Input(shape=[time_window, env_input_dimension])
75 |
76 | # Activations to apply
77 | if isinstance(activation, str):
78 | activations = [activation] * layers
79 | else:
80 | activations = activation
81 |
82 | env_proj_1 = env1
83 | env_proj_2 = env2
84 | # Spatial convolution
85 | eeg_proj_1 = tf.keras.layers.Conv1D(spatial_filters, kernel_size=1)(eeg)
86 |
87 | # Construct dilation layers
88 | for layer_index in range(layers):
89 | # dilation on EEG
90 | eeg_proj_1 = tf.keras.layers.Conv1D(
91 | dilation_filters,
92 | kernel_size=kernel_size,
93 | dilation_rate=kernel_size**layer_index,
94 | strides=1,
95 | activation=activations[layer_index],
96 | )(eeg_proj_1)
97 |
98 | # Dilation on envelope data, share weights
99 | env_proj_layer = tf.keras.layers.Conv1D(
100 | dilation_filters,
101 | kernel_size=kernel_size,
102 | dilation_rate=kernel_size**layer_index,
103 | strides=1,
104 | activation=activations[layer_index],
105 | )
106 | env_proj_1 = env_proj_layer(env_proj_1)
107 | env_proj_2 = env_proj_layer(env_proj_2)
108 |
109 | # Comparison
110 | cos1 = tf.keras.layers.Dot(1, normalize=True)([eeg_proj_1, env_proj_1])
111 | cos2 = tf.keras.layers.Dot(1, normalize=True)([eeg_proj_1, env_proj_2])
112 |
113 | # Classification
114 | out = tf.keras.layers.Dense(1, activation="sigmoid")(
115 | tf.keras.layers.Flatten()(tf.keras.layers.Concatenate()([cos1, cos2]))
116 | )
117 |
118 | model = tf.keras.Model(inputs=[eeg, env1, env2], outputs=[out])
119 |
120 | if compile:
121 | model.compile(
122 | optimizer=tf.keras.optimizers.Adam(),
123 | metrics=["acc"],
124 | loss=["binary_crossentropy"],
125 | )
126 | model.summary()
127 | return model
128 |
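129 |
130 | # Example usage (a minimal sketch, mirroring the match-mismatch experiment):
131 | # build the model for 5-second windows at 64 Hz, with a 1-dimensional envelope
132 | # as the stimulus representation. Compiling also prints the model summary.
133 | if __name__ == "__main__":
134 |     mm_model = dilation_model(time_window=5 * 64, eeg_input_dimension=64, env_input_dimension=1)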
--------------------------------------------------------------------------------
/technical_validation/util/split_and_normalize.py:
--------------------------------------------------------------------------------
1 | """Split data in sets and normalize (per recording)."""
2 | import glob
3 | import json
4 | import os
5 | import pickle
6 | import numpy as np
7 |
8 |
9 | if __name__ == "__main__":
10 |
11 | # Arguments for splitting and normalizing
12 | speech_features = ['envelope']
13 | splits = [80, 10, 10]
14 | split_names = ['train', 'val', 'test']
15 | overwrite = False
16 |
17 | # Calculate the split fraction
18 | split_fractions = [x/sum(splits) for x in splits]
19 |
20 | # Get the path to the config file
21 | main_folder = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
22 | config_path = os.path.join(main_folder, 'config.json')
23 |
24 | # Load the config
25 | with open(config_path) as fp:
26 | config = json.load(fp)
27 |
28 | # Construct the necessary paths
29 | processed_eeg_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"],config["preprocessed_eeg_folder"])
30 | processed_stimuli_folder = os.path.join(config["dataset_folder"],config["derivatives_folder"], config["preprocessed_stimuli_folder"])
31 | split_data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
32 |
33 | # Create the output folder
34 | os.makedirs(split_data_folder, exist_ok=True)
35 |
36 | # Find all subjects
37 | all_subjects = glob.glob(os.path.join(processed_eeg_folder, "sub*"))
38 | nb_subjects = len(all_subjects)
39 | print(f"Found {nb_subjects} subjects to split/normalize")
40 |
41 | # Loop over subjects
42 | for subject_index, subject_path in enumerate(all_subjects):
43 | subject = os.path.basename(subject_path)
44 | print(f"Starting with subject {subject} ({subject_index + 1}/{nb_subjects})...")
45 | # Find all recordings
46 | all_recordings = glob.glob(os.path.join(subject_path, "*", "*.npy"))
47 | print(f"\tFound {len(all_recordings)} recordings for subject {subject}.")
48 | # Loop over recordings
49 | for recording_index, recording in enumerate(all_recordings):
50 | print(f"\tStarting with recording {recording} ({recording_index + 1}/{len(all_recordings)})...")
51 |
52 | # Load EEG from disk
53 | print(f"\t\tLoading EEG for {recording}")
54 | eeg = np.load(recording)
55 |
56 | # swap axes to have time as first dimension
57 | eeg = np.swapaxes(eeg, 0, 1)
58 |
59 | # keep only the 64 channels
60 | eeg = eeg[:, :64]
61 |
62 | # retrieve the stimulus name from the filename
63 | stimulus_filename = recording.split('_eeg.')[0].split('-audio-')[1]
64 |
65 | # Retrieve EEG data and pointer to the stimulus
66 | shortest_length = eeg.shape[0]
67 |
68 | # Create mapping between feature name and feature data
69 | all_data_for_recording = {"eeg": eeg}
70 |
71 | # Find corresponding stimuli for the EEG recording
72 | for feature_name in speech_features:
73 | # Load feature from disk
74 | print(f"\t\tLoading {feature_name} for recording {recording} ")
75 | stimulus_feature_path = os.path.join(
76 | processed_stimuli_folder,
77 | stimulus_filename + "_" + feature_name + ".npy",
78 | )
79 | feature = np.load(stimulus_feature_path)
80 | # Calculate the shortest length
81 | shortest_length = min(feature.shape[0], shortest_length)
82 | # Update all_data_for_recording
83 | all_data_for_recording[feature_name] = feature
84 |
85 | # Do the actual splitting
86 | print(f"\t\tSplitting/normalizing recording {recording}...")
87 | for feature_name, feature in all_data_for_recording.items():
88 | start_index = 0
89 | feature_mean = None
90 | feature_std = None
91 |
92 | for split_name, split_fraction in zip(split_names, split_fractions):
93 | end_index = start_index + int(shortest_length * split_fraction)
94 |
95 | # Cut the feature to the shortest length
96 | cut_feature = feature[start_index:end_index, ...]
97 |
98 | # Normalize the feature
99 | if feature_mean is None:
100 | feature_mean = np.mean(cut_feature, axis=0)
101 | feature_std = np.std(cut_feature, axis=0)
102 | norm_feature = (cut_feature - feature_mean)/feature_std
103 |
104 | # Save the feature
105 | save_filename = f"{split_name}_-_{subject}_-_{stimulus_filename}_-_{feature_name}.npy"
106 | save_path = os.path.join(split_data_folder, save_filename)
107 | if not os.path.exists(save_path) or overwrite:
108 | np.save(save_path, norm_feature)
109 | else:
110 | print(f"\t\tSkipping {save_filename} because it already exists")
111 | start_index = end_index
112 |
113 |
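114 | # Run this script once before running the experiments, e.g.:
115 | #
116 | #     python3 technical_validation/util/split_and_normalize.py
117 | #
118 | # It writes files named {split}_-_{subject}_-_{stimulus}_-_{feature}.npy into
119 | # the split folder configured in config.json.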
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
163 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
164 |
165 | # User-specific stuff
166 | .idea/**/workspace.xml
167 | .idea/**/tasks.xml
168 | .idea/**/usage.statistics.xml
169 | .idea/**/dictionaries
170 | .idea/**/shelf
171 |
172 | # AWS User-specific
173 | .idea/**/aws.xml
174 |
175 | # Generated files
176 | .idea/**/contentModel.xml
177 |
178 | # Sensitive or high-churn files
179 | .idea/**/dataSources/
180 | .idea/**/dataSources.ids
181 | .idea/**/dataSources.local.xml
182 | .idea/**/sqlDataSources.xml
183 | .idea/**/dynamic.xml
184 | .idea/**/uiDesigner.xml
185 | .idea/**/dbnavigator.xml
186 |
187 | # Gradle
188 | .idea/**/gradle.xml
189 | .idea/**/libraries
190 |
191 | # Gradle and Maven with auto-import
192 | # When using Gradle or Maven with auto-import, you should exclude module files,
193 | # since they will be recreated, and may cause churn. Uncomment if using
194 | # auto-import.
195 | # .idea/artifacts
196 | # .idea/compiler.xml
197 | # .idea/jarRepositories.xml
198 | # .idea/modules.xml
199 | # .idea/*.iml
200 | # .idea/modules
201 | # *.iml
202 | # *.ipr
203 |
204 | # CMake
205 | cmake-build-*/
206 |
207 | # Mongo Explorer plugin
208 | .idea/**/mongoSettings.xml
209 |
210 | # File-based project format
211 | *.iws
212 |
213 | # IntelliJ
214 | out/
215 |
216 | # mpeltonen/sbt-idea plugin
217 | .idea_modules/
218 |
219 | # JIRA plugin
220 | atlassian-ide-plugin.xml
221 |
222 | # Cursive Clojure plugin
223 | .idea/replstate.xml
224 |
225 | # SonarLint plugin
226 | .idea/sonarlint/
227 |
228 | # Crashlytics plugin (for Android Studio and IntelliJ)
229 | com_crashlytics_export_strings.xml
230 | crashlytics.properties
231 | crashlytics-build.properties
232 | fabric.properties
233 |
234 | # Editor-based Rest Client
235 | .idea/httpRequests
236 |
237 | # Android studio 3.1+ serialized cache file
238 | .idea/caches/build_file_checksums.ser
239 |
240 |
241 | # No condor job files
242 | technical_validation/condor/
243 | *.job
244 | technical_validation/experiments/results*
245 | technical_validation/experiments/figures/
246 | # No tests with classified data
247 | tests/test_classified/*
248 | # No .py in progress files
249 | *.py~
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_vlaai.py:
--------------------------------------------------------------------------------
1 | """Example experiment for the VLAAI model."""
2 | import glob
3 | import json
4 | import logging
5 | import os
6 | import tensorflow as tf
7 | import sys
8 |
9 |
10 | from technical_validation.models.vlaai import vlaai, pearson_loss, pearson_metric
11 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
12 |
13 |
14 | def evaluate_model(model, test_dict):
15 | """Evaluate a model.
16 |
17 | Parameters
18 | ----------
19 | model: tf.keras.Model
20 | Model to evaluate.
21 | test_dict: dict
22 | Mapping between a subject and a tf.data.Dataset containing the test
23 | set for the subject.
24 |
25 | Returns
26 | -------
27 | dict
28 | Mapping between a subject and the loss/evaluation score on the test set
29 | """
30 | evaluation = {}
31 | for subject, ds_test in test_dict.items():
32 | logging.info(f"Scores for subject {subject}:")
33 | results = model.evaluate(ds_test, verbose=2)
34 | metrics = model.metrics_names
35 | evaluation[subject] = dict(zip(metrics, results))
36 | return evaluation
37 |
38 |
39 | if __name__ == "__main__":
40 |
41 | gpus = tf.config.list_physical_devices('GPU')
42 | print(gpus)
43 | # if gpus:
44 | # try:
45 | # # Currently, memory growth needs to be the same across GPUs
46 | # for gpu in gpus:
47 | # tf.config.experimental.set_memory_growth(gpu, True)
48 | # logical_gpus = tf.config.list_logical_devices('GPU')
49 | # print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
50 | # except RuntimeError as e:
51 | # # Memory growth must be set before GPUs have been initialized
52 | # print(e)
53 | # Parameters
54 | # Length of the decision window
55 | window_length = 10 * 64 # 10 seconds
56 | # Hop length between two consecutive decision windows
57 | hop_length = 64
58 | epochs = 100
59 | patience = 5
60 | batch_size = 64
61 | only_evaluate = False
62 | training_log_filename = "training_log_%d.csv" %window_length
63 | results_filename = 'eval_%d.json' % window_length
64 |
65 |
66 | # Get the path to the config file
67 | experiments_folder = os.path.dirname(__file__)
68 | main_folder = os.path.dirname(os.path.dirname(experiments_folder))
69 | config_path = os.path.join(main_folder, 'config.json')
70 |
71 | # Load the config
72 | with open(config_path) as fp:
73 | config = json.load(fp)
74 |
75 | # Provide the path of the dataset
76 | # which is split already to train, val, test
77 |
78 | data_folder = os.path.join(config["dataset_folder"], config["derivatives"], config["split_folder"])
79 | stimulus_features = ["envelope"]
80 | features = ["eeg"] + stimulus_features
81 |
82 | # Create a directory to store (intermediate) results
83 | results_folder = os.path.join(experiments_folder, "results_vlaai")
84 | os.makedirs(results_folder, exist_ok=True)
85 |
86 | # create the model
87 | model = vlaai()
88 | model.compile(tf.keras.optimizers.Adam(), loss=pearson_loss, metrics=[pearson_metric])
89 | model_path = os.path.join(results_folder, "model_%d.h5" % window_length)
90 |
91 | if only_evaluate:
92 | model.load_weights(model_path)
93 | else:
94 |
95 | train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features ]
96 | # Create list of numpy array files
97 | train_generator = RegressionDataGenerator(train_files, window_length)
98 | dataset_train = create_tf_dataset(train_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64,1))
99 |
100 | # Create the generator for the validation set
101 | val_files = [x for x in glob.glob(os.path.join(data_folder, "val_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features ]
102 | val_generator = RegressionDataGenerator(val_files, window_length)
103 | dataset_val = create_tf_dataset(val_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64,1))
104 |
105 | # Train the model
106 | model.fit(
107 | dataset_train,
108 | epochs=epochs,
109 | validation_data=dataset_val,
110 | callbacks=[
111 | tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True),
112 | tf.keras.callbacks.CSVLogger(os.path.join(results_folder, training_log_filename)),
113 | tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=True),
114 | ],
115 | )
116 |
117 | # Evaluate the model on test set
118 | # Create a dataset generator for each test subject
119 | test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
120 | # Get all different subjects from the test set
121 | subjects = list(set([os.path.basename(x).split("_-_")[1] for x in test_files]))
122 | datasets_test = {}
123 | results_filename = 'eval_pretrained_%d.json' % window_length
124 | # Create a generator for each subject
125 | for sub in subjects:
126 | files_test_sub = [f for f in test_files if sub in os.path.basename(f)]
127 |
128 | test_generator = RegressionDataGenerator(files_test_sub, window_length)
129 | datasets_test[sub] = create_tf_dataset(test_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64,1))
130 |
131 | # Evaluate the model
132 | evaluation = evaluate_model(model, datasets_test)
133 |
134 | # We can save our results in a json encoded file
135 | results_path = os.path.join(results_folder, results_filename)
136 | with open(results_path, "w") as fp:
137 | json.dump(evaluation, fp)
138 | logging.info(f"Results saved at {results_path}")
139 |
140 |
--------------------------------------------------------------------------------
/technical_validation/experiments/match_mismatch_dilated_convolutional_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for dilation model."""
2 | import glob
3 | import json
4 | import logging
5 | import os
6 | os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
7 | os.environ['TF_XLA_FLAGS'] = "--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
8 | import tensorflow as tf
9 | import sys
10 |
11 |
12 | from technical_validation.models.dilated_convolutional_model import dilation_model
13 | from technical_validation.util.dataset_generator import MatchMismatchDataGenerator, default_batch_equalizer_fn, create_tf_dataset
14 |
15 |
16 | def evaluate_model(model, test_dict):
17 | """Evaluate a model.
18 |
19 | Parameters
20 | ----------
21 | model: tf.keras.Model
22 | Model to evaluate.
23 | test_dict: dict
24 | Mapping between a subject and a tf.data.Dataset containing the test
25 | set for the subject.
26 |
27 | Returns
28 | -------
29 | dict
30 | Mapping between a subject and the loss/evaluation score on the test set
31 | """
32 | evaluation = {}
33 | for subject, ds_test in test_dict.items():
34 | logging.info(f"Scores for subject {subject}:")
35 | results = model.evaluate(ds_test, verbose=2)
36 | metrics = model.metrics_names
37 | evaluation[subject] = dict(zip(metrics, results))
38 | return evaluation
39 |
40 |
41 | if __name__ == "__main__":
42 | # Parameters
43 | # Length of the decision window
44 | window_length = 5 * 64 # 5 seconds
45 | # Hop length between two consecutive decision windows
46 | hop_length = 64
47 | # Number of samples (space) between end of matched speech and beginning of mismatched speech
48 | spacing = 64
49 | epochs = 100
50 | patience = 5
51 | batch_size = 64
52 |
53 | only_evaluate = False
54 |
55 | training_log_filename = "training_log.csv"
56 | results_filename = 'eval.json'
57 |
58 |
59 | # Get the path to the config file
60 | experiments_folder = os.path.dirname(__file__)
61 | main_folder = os.path.dirname(os.path.dirname(experiments_folder))
62 | config_path = os.path.join(main_folder, 'config.json')
63 |
64 | # Load the config
65 | with open(config_path) as fp:
66 | config = json.load(fp)
67 |
68 | # Provide the path of the dataset
69 | # which is split already to train, val, test
70 | data_folder = os.path.join(config["dataset_folder"],config["derivatives"], config["split_folder"])
71 |
72 | # Stimulus feature used for training the model: 'envelope' (dimension 1) or 'mel' (dimension 28)
73 | stimulus_features = ["envelope"]
74 | stimulus_dimension = 1
75 |
76 | features = ["eeg"] + stimulus_features
77 |
78 | # Create a directory to store (intermediate) results
79 | results_folder = os.path.join(experiments_folder, "results_dilated_convolutional_model")
80 |
81 | os.makedirs(results_folder, exist_ok=True)
82 |
83 | # create dilation model
84 | model = dilation_model(time_window=window_length, eeg_input_dimension=64, env_input_dimension=stimulus_dimension)
85 | model_path = os.path.join(results_folder, "model.h5")
86 |
87 | if only_evaluate:
88 | model = tf.keras.models.load_model(model_path)
89 | else:
90 |
91 | train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
92 | # Create list of numpy array files
93 | train_generator = MatchMismatchDataGenerator(train_files, window_length, spacing=spacing)
94 | dataset_train = create_tf_dataset(train_generator, window_length, default_batch_equalizer_fn, hop_length, batch_size)
95 |
96 | # Create the generator for the validation set
97 | val_files = [x for x in glob.glob(os.path.join(data_folder, "val_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
98 | val_generator = MatchMismatchDataGenerator(val_files, window_length, spacing=spacing)
99 | dataset_val = create_tf_dataset(val_generator, window_length, default_batch_equalizer_fn, hop_length, batch_size)
100 |
101 | # Train the model
102 | model.fit(
103 | dataset_train,
104 | epochs=epochs,
105 | validation_data=dataset_val,
106 | callbacks=[
107 | tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True),
108 | tf.keras.callbacks.CSVLogger(os.path.join(results_folder, training_log_filename)),
109 | tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=True),
110 | ],
111 | )
112 |
113 |
114 |
115 | # Evaluate the model on test set
116 | # Create a dataset generator for each test subject
117 | test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features ]
118 | # Get all different subjects from the test set
119 | subjects = list(set([os.path.basename(x).split("_-_")[1] for x in test_files]))
120 |
121 |
122 | test_dict = {}
123 | # evaluate on different window lengths
124 | window_lengths = [64, 128, 256, 320, 640, 20 * 64]
125 | for window_length in window_lengths:
126 | model = dilation_model(time_window=window_length, eeg_input_dimension=64,
127 | env_input_dimension=stimulus_dimension)
128 | model_path = os.path.join(results_folder, "model.h5")
129 | model.load_weights(model_path)
130 |
131 | datasets_test = {}
132 | # Create a generator for each subject
133 | for sub in subjects:
134 | files_test_sub = [f for f in test_files if sub in os.path.basename(f)]
135 | test_generator = MatchMismatchDataGenerator(files_test_sub, window_length, spacing=spacing)
136 | datasets_test[sub] = create_tf_dataset(test_generator, window_length, default_batch_equalizer_fn, hop_length, 1,)
137 |
138 | # Evaluate the model
139 | evaluation = evaluate_model(model, datasets_test)
140 |
141 | # We can save our results in a json encoded file
142 | results_path = os.path.join(results_folder, results_filename.split(".")[0] + f"_{window_length}"+ ".json")
143 | with open(results_path, "w") as fp:
144 | json.dump(evaluation, fp)
145 | logging.info(f"Results saved at {results_path}")
146 |
147 |
148 |
--------------------------------------------------------------------------------
/technical_validation/models/vlaai.py:
--------------------------------------------------------------------------------
1 | """Code to construct the VLAAI network.
2 | Code was extracted from https://github.com/exporl/vlaai
3 | """
4 | import tensorflow as tf
5 |
6 |
7 | def extractor(
8 | filters=(256, 256, 256, 128, 128),
9 | kernels=(8,) * 5,
10 | input_channels=64,
11 | normalization_fn=lambda x: tf.keras.layers.LayerNormalization()(x),
12 | activation_fn=lambda x: tf.keras.layers.LeakyReLU()(x),
13 | name="extractor",
14 | ):
15 | """Construct the extractor model.
16 |
17 | Parameters
18 | ----------
19 | filters: Sequence[int]
20 | Number of filters for each layer.
21 | kernels: Sequence[int]
22 | Kernel size for each layer.
23 | input_channels: int
24 | Number of EEG channels in the input
25 | normalization_fn: Callable[[tf.Tensor], tf.Tensor]
26 | Function to normalize the contents of a tensor.
27 | activation_fn: Callable[[tf.Tensor], tf.Tensor]
28 | Function to apply an activation function to the contents of a tensor.
29 | name: str
30 | Name of the model.
31 |
32 | Returns
33 | -------
34 | tf.keras.models.Model
35 | The extractor model.
36 | """
37 | eeg = tf.keras.layers.Input((None, input_channels))
38 |
39 | x = eeg
40 |
41 | if len(filters) != len(kernels):
42 | raise ValueError("'filters' and 'kernels' must have the same length")
43 |
44 | # Add the convolutional layers
45 | for filter_, kernel in zip(filters, kernels):
46 | x = tf.keras.layers.Conv1D(filter_, kernel)(x)
47 | x = normalization_fn(x)
48 | x = activation_fn(x)
49 | x = tf.keras.layers.ZeroPadding1D((0, kernel - 1))(x)
50 |
51 | return tf.keras.models.Model(inputs=[eeg], outputs=[x], name=name)
52 |
53 |
54 | def output_context(
55 | filter_=64,
56 | kernel=32,
57 | input_channels=64,
58 | normalization_fn=lambda x: tf.keras.layers.LayerNormalization()(x),
59 | activation_fn=lambda x: tf.keras.layers.LeakyReLU()(x),
60 | name="output_context_model",
61 | ):
62 | """Construct the output context model.
63 |
64 | Parameters
65 | ----------
66 | filter_: int
67 | Number of filters for the convolutional layer.
68 | kernel: int
69 | Kernel size for the convolutional layer.
70 | input_channels: int
71 | Number of EEG channels in the input.
72 | normalization_fn: Callable[[tf.Tensor], tf.Tensor]
73 | Function to normalize the contents of a tensor.
74 | activation_fn: Callable[[tf.Tensor], tf.Tensor]
75 | Function to apply an activation function to the contents of a tensor.
76 | name: str
77 | Name of the model.
78 |
79 | Returns
80 | -------
81 | tf.keras.models.Model
82 | The output context model.
83 | """
84 | inp = tf.keras.layers.Input((None, input_channels))
85 | x = tf.keras.layers.ZeroPadding1D((kernel - 1, 0))(inp)
86 | x = tf.keras.layers.Conv1D(filter_, kernel)(x)
87 | x = normalization_fn(x)
88 | x = activation_fn(x)
89 | return tf.keras.models.Model(inputs=[inp], outputs=[x], name=name)
90 |
91 |
92 | def vlaai(
93 | nb_blocks=4,
94 | extractor_model=None,
95 | output_context_model=None,
96 | use_skip=True,
97 | input_channels=64,
98 | output_dim=1,
99 | name="vlaai",
100 | ):
101 | """Construct the VLAAI model.
102 |
103 | Parameters
104 | ----------
105 | nb_blocks: int
106 | Number of repeated blocks to use.
107 | extractor_model: Callable[[tf.Tensor], tf.Tensor]
108 | The extractor model to use.
109 | output_context_model: Callable[[tf.Tensor], tf.Tensor]
110 | The output context model to use.
111 | use_skip: bool
112 | Whether to use skip connections.
113 | input_channels: int
114 | Number of EEG channels in the input.
115 | output_dim: int
116 | Number of output dimensions.
117 | name: str
118 | Name of the model.
119 |
120 | Returns
121 | -------
122 | tf.keras.models.Model
123 | The VLAAI model.
124 | """
125 | if extractor_model is None:
126 | extractor_model = extractor()
127 | if output_context_model is None:
128 | output_context_model = output_context()
129 |
130 | eeg = tf.keras.layers.Input((None, input_channels))
131 |
132 | # If using skip connections: start with x set to zero
133 | if use_skip:
134 | x = tf.zeros_like(eeg)
135 | else:
136 | x = eeg
137 |
138 | # Iterate over the blocks
139 | for i in range(nb_blocks):
140 | if use_skip:
141 | x = extractor_model(eeg + x)
142 | else:
143 | x = extractor_model(x)
144 | x = tf.keras.layers.Dense(input_channels)(x)
145 | x = output_context_model(x)
146 |
147 | x = tf.keras.layers.Dense(output_dim)(x)
148 |
149 | return tf.keras.models.Model(inputs=[eeg], outputs=[x], name=name)
150 |
151 |
152 | def pearson_tf(y_true, y_pred, axis=1):
153 | """Pearson correlation function implemented in tensorflow.
154 |
155 | Parameters
156 | ----------
157 | y_true: tf.Tensor
158 | Ground truth labels. Shape is (batch_size, time_steps, n_features)
159 | y_pred: tf.Tensor
160 | Predicted labels. Shape is (batch_size, time_steps, n_features)
161 | axis: int
162 | Axis along which to compute the pearson correlation. Default is 1.
163 |
164 | Returns
165 | -------
166 | tf.Tensor
167 | Pearson correlation.
168 | Shape is (batch_size, 1, n_features) if axis is 1.
169 | """
170 | # Compute the mean of the true and predicted values
171 | y_true_mean = tf.reduce_mean(y_true, axis=axis, keepdims=True)
172 | y_pred_mean = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
173 |
174 | # Compute the numerator and denominator of the pearson correlation
175 | numerator = tf.reduce_sum(
176 | (y_true - y_true_mean) * (y_pred - y_pred_mean),
177 | axis=axis,
178 | keepdims=True,
179 | )
180 | std_true = tf.reduce_sum(tf.square(y_true - y_true_mean), axis=axis, keepdims=True)
181 | std_pred = tf.reduce_sum(tf.square(y_pred - y_pred_mean), axis=axis, keepdims=True)
182 | denominator = tf.sqrt(std_true * std_pred)
183 |
184 | # Compute the pearson correlation
185 | return tf.math.divide_no_nan(numerator, denominator)
186 |
187 |
188 | @tf.function
189 | def pearson_loss(y_true, y_pred, axis=1):
190 | """Pearson loss function.
191 |
192 | Parameters
193 | ----------
194 | y_true: tf.Tensor
195 | True values. Shape is (batch_size, time_steps, n_features)
196 | y_pred: tf.Tensor
197 | Predicted values. Shape is (batch_size, time_steps, n_features)
198 |
199 | Returns
200 | -------
201 | tf.Tensor
202 | Pearson loss.
203 | Shape is (batch_size, 1, n_features)
204 | """
205 | return -pearson_tf(y_true, y_pred, axis=axis)
206 |
207 |
208 | @tf.function
209 | def pearson_metric(y_true, y_pred, axis=1):
210 | """Pearson metric function.
211 |
212 | Parameters
213 | ----------
214 | y_true: tf.Tensor
215 | True values. Shape is (batch_size, time_steps, n_features)
216 | y_pred: tf.Tensor
217 | Predicted values. Shape is (batch_size, time_steps, n_features)
218 |
219 | Returns
220 | -------
221 | tf.Tensor
222 | Pearson metric.
223 | Shape is (batch_size, 1, n_features)
224 | """
225 | return pearson_tf(y_true, y_pred, axis=axis)
226 |
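227 |
228 | # Example usage (a minimal sketch, mirroring regression_vlaai.py): build the
229 | # default VLAAI network and compile it with the Pearson loss and metric.
230 | if __name__ == "__main__":
231 |     model = vlaai()
232 |     model.compile(tf.keras.optimizers.Adam(), loss=pearson_loss, metrics=[pearson_metric])
233 |     model.summary()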
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code to preprocess the SparrKULee dataset
2 | =========================================
3 | This is the codebase to preprocess and validate the [SparrKULee](https://doi.org/10.48804/K3VSND) dataset.
4 | This codebase consists of two main parts:
5 | 1) preprocessing code, to preprocess the raw data into an easily usable format
6 | 2) technical validation code, to validate the technical quality of the dataset.
7 | This code is used to generate the results in the dataset paper and assumes that the preprocessing pipeline has been run
8 |
9 | Requirements
10 | ------------
11 |
12 | Python >= 3.7
13 |
14 | # General setup
15 |
16 | Steps to get a working setup:
17 |
18 | ## 1. Clone this repository and install the [requirements.txt](requirements.txt)
19 | ```bash
20 | # Clone this repository
21 | git clone https://github.com/exporl/auditory-eeg-dataset
22 |
23 | # Go to the root folder
24 | cd auditory-eeg-dataset
25 |
26 | # Optional: install a virtual environment
27 | python3 -m venv venv # Optional
28 | source venv/bin/activate # Optional
29 |
30 | # Install the dependencies listed in requirements.txt
31 | python3 -m pip install -r requirements.txt
32 | ```
33 |
34 | ## 2. [Download (parts of the) data](download_code/README.md)
35 |
36 | **We recommend accessing the dataset via our [website](https://homes.esat.kuleuven.be/~spchdata/corpora/auditory_eeg_data/), where we also share instructions to mount the dataset as an alternative to downloading.**
37 |
38 | The dataset is also hosted on the [KU Leuven RDR website](https://doi.org/10.48804/K3VSND) and is accessible through DOI ([https://doi.org/10.48804/K3VSND](https://doi.org/10.48804/K3VSND)).
39 |
40 | However, due to the dataset size/structure and the limitations of the UI of the KU Leuven RDR website, we also provide a [direct download link for the entire dataset in `.zip` format](https://rdr.kuleuven.be/api/access/dataset/:persistentId/?persistentId=doi:10.48804/K3VSND), a [OneDrive repository containing the entire dataset split up into smaller files](https://kuleuven-my.sharepoint.com/:f:/g/personal/lies_bollens_kuleuven_be/EulH76nkcwxIuK--XJhLxKQBaX8_GgAX-rTKK7mskzmAZA?e=N6M5Ll) and a [tool to download (subsets of) the dataset robustly](download_code/README.md).
41 | For more information about the tool, see [download_code/README.md](download_code/README.md).
42 |
43 | Due to privacy concerns, not all data is publicly available. Users requesting access to these files should send an email to the authors (lies.bollens@kuleuven.be; bernd.accou@kuleuven.be), stating what they want to use the data for. Access will be granted to non-commercial users, in compliance with the CC-BY-NC-4.0 licence.
44 |
45 | When you want to start directly from the preprocessed data (which is the output you will get when running
46 | [preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)),
47 | you can download the **derivatives** folder. This folder contains all the necessary files to run the technical validation. It can also be downloaded using [the download tool](download_code/README.md) as follows:
48 |
49 | ```bash
50 | python3 download_code/download_script.py --subset preprocessed /path/to/local/folder
51 | ```
52 |
53 |
54 | ## 3. Adjust the [config.json](config.json) accordingly
55 |
56 | The [config.json](config.json) file defines the folder names and structure for the data and derivatives folders.
57 | Adjust `dataset_folder` in the [config.json](config.json) file from `null` to the absolute path of the folder containing all data.
58 |
59 |
60 | OK, you should be all set up now!
61 |
62 | Preprocessing code
63 | ==================
64 |
65 | This repository uses the [brain_pipe package](https://github.com/exporl/brain_pipe)
66 | to preprocess the data. It is installed automatically when installing the [requirements.txt](requirements.txt).
67 | You are invited to contribute to the [brain_pipe](https://github.com/exporl/brain_pipe) package if you want to add new preprocessing steps.
68 | Documentation for the brain_pipe package can be found [here](https://exporl.github.io/brain_pipe/).
69 |
70 | Example usage
71 | -------------
72 |
73 | There are multiple ways to run the preprocessing pipeline, specified below.
74 |
75 | **Warning:** the script and the YAML file will create both Mel spectrograms and envelope representations of the stimuli.
76 | If this is not desired, you can comment out the appropriate lines.
77 |
78 | **Make sure your [brain_pipe](https://github.com/exporl/brain_pipe) version is up to date (>= 0.0.3)!**
79 | You can ensure this by running `pip3 install --upgrade brain_pipe` or `pip3 install --upgrade -r requirements.txt`.
80 |
81 | ### 1. Use the python script [preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)
82 |
83 | ```bash
84 | python3 preprocessing_code/sparrKULee.py
85 | ```
86 |
87 | Different options (such as the number of parallel processes) can be specified from the command line.
88 | For more information, run:
89 |
90 | ```bash
91 | python3 preprocessing_code/sparrKULee.py --help
92 | ```
93 |
94 | ### 2. Use the YAML file with the [brain_pipe](https://github.com/exporl/brain_pipe) CLI
95 |
96 | For this option, you will have to fill in the `--dataset_folder`, `--derivatives_folder`,
97 | `--preprocessed_stimuli_dir` and `--preprocessed_eeg_dir` with the values from the [config.json](config.json) file.
98 |
99 | ```bash
100 | brain_pipe preprocessing_code/sparrKULee.yaml --dataset_folder {/path/to/dataset} --derivatives_folder {derivatives_folder} --preprocessed_stimuli_dir {preprocessed_stimuli_dir} --preprocessed_eeg_dir {preprocessed_eeg_dir}
101 | ```
102 |
103 | Optionally, you could read the [config.json](config.json) file directly from the command line:
104 |
105 | ```bash
106 | brain_pipe preprocessing_code/sparrKULee.yaml $(python3 -c "import json; f=open('config.json'); d=json.load(f); f.close(); print(' '.join([f'--{x}={y}' for x,y in d.items() if 'split_folder' != x]))")
107 | ```
108 |
109 | For more information about the [brain_pipe](https://github.com/exporl/brain_pipe) CLI,
110 | see the appropriate documentation for the [CLI](https://exporl.github.io/brain_pipe/cli.html) and [configuration files (e.g. YAML)](https://exporl.github.io/brain_pipe/configuration.html).
111 |
112 | Technical validation
113 | ====================
114 | This repository contains code to validate the preprocessed dataset using established models.
115 | Running this code will yield the results summarized in the paper. (LINK TO PAPER)
116 |
117 | Prerequisites
118 | -------------
119 | The technical validation code assumes that the preprocessing pipeline has been run and that the derivatives folder is available.
120 | The derivatives folder contains the preprocessed data and the necessary files to run the technical validation code.
121 | Either download the derivatives folder directly from the online dataset, or run the preprocessing pipeline yourself ([preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)).
122 |
123 | Example usage
124 | -------------
125 |
126 | We have defined some ready-to-go experiments to replicate the results summarized in the dataset paper.
127 | All these experiments use split (into training/validation/test partitions) and normalised data, which can be obtained by
128 | running [technical_validation/util/split_and_normalize.py](technical_validation/util/split_and_normalize.py).
129 |
130 | The experiment files live in the [technical_validation/experiments](technical_validation/experiments) folder. The training log,
131 | best model and evaluation results will be stored in a folder called
132 | `results_{experiment_name}`.
133 |
134 | To replicate the results summarized in the dataset paper, run the following experiments:
135 | ```bash
136 | # train the dilated convolutional model introduced by Accou et al. (https://doi.org/10.1088/1741-2552/ac33e9)
137 | python3 match_mismatch_dilated_convolutional_model.py
138 |
139 | # train a simple linear backward model, reconstructing the envelope from EEG
140 | # using filtered data in different frequency bands
141 | # (a simple linear baseline model, evaluated with Pearson correlation, similar to the baseline used in Accou et al. (2022) (https://www.biorxiv.org/content/10.1101/2022.09.28.509945))
142 |
143 | python3 regression_linear_backward_model.py --highpass 0.5 --lowpass 30
144 | python3 regression_linear_backward_model.py --highpass 0.5 --lowpass 4
145 | python3 regression_linear_backward_model.py --highpass 4 --lowpass 8
146 | python3 regression_linear_backward_model.py --highpass 8 --lowpass 14
147 | python3 regression_linear_backward_model.py --highpass 14 --lowpass 30
148 |
149 | # train a simple linear forward model, predicting the EEG response from the envelope,
150 | # using filtered data in different frequency bands
151 | python3 regression_linear_forward_model.py --highpass 0.5 --lowpass 30
152 | python3 regression_linear_forward_model.py --highpass 0.5 --lowpass 4
153 |
154 | # train/evaluate the VLAAI model as proposed by Accou et al. (2022) (https://www.biorxiv.org/content/10.1101/2022.09.28.509945). A pre-trained model is available on the VLAAI GitHub page (https://github.com/exporl/vlaai).
155 | python3 regression_vlaai.py
156 | ```
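157 |
158 | The `--highpass`/`--lowpass` pairs above correspond to the standard EEG frequency bands used
159 | throughout the validation code (cf. `freq_bands` in [technical_validation/util/plot_results.py](technical_validation/util/plot_results.py)):
160 |
161 | ```python
162 | # Band name -> (highpass, lowpass) cutoff frequencies in Hz.
163 | freq_bands = {
164 |     "delta": (0.5, 4.0),
165 |     "theta": (4.0, 8.0),
166 |     "alpha": (8.0, 14.0),
167 |     "beta": (14.0, 30.0),
168 |     "broadband": (0.5, 32.0),
169 | }
170 | ```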
157 |
158 | Finally, you can generate the plots shown in the dataset paper by running the [technical_validation/util/plot_results.py](technical_validation/util/plot_results.py) script: `python3 technical_validation/util/plot_results.py`.
159 |
160 |
--------------------------------------------------------------------------------
/download_code/__init__.py:
--------------------------------------------------------------------------------
1 | """Code to parse and download Dataverse datasets."""
2 | import datetime
3 | import hashlib
4 | import json
5 | import multiprocessing as mp
6 | import os
7 | import urllib.request
8 |
9 |
10 |
11 | class DataverseDownloader:
12 | """Download files from a Dataverse dataset."""
13 |
14 | def __init__(
15 | self,
16 | download_path,
17 | server,
18 | overwrite=False,
19 | check_md5=True,
20 | multiprocessing=-1,
21 | datetime_format="%Y-%m-%d %H:%M:%S",
22 | ):
23 | """Create a new DataverseDownloader.
24 |
25 | Parameters
26 | ----------
27 | download_path: str
28 | The path to download the files to. The path will be created if it does not
29 | exist.
30 | server: str
31 | The hostname of the server to download the files from.
32 | overwrite: bool
33 | Whether to overwrite existing files.
34 | check_md5: bool
35 | Whether to check the MD5 checksum of the downloaded files.
36 | multiprocessing: int
37 | The number of cores to use for multiprocessing. Set to 0 or 1 to disable
38 | multiprocessing. The default -1 uses all available cores.
39 | datetime_format: str
40 | The datetime format to use for printing the start and end time of the
41 | download.
42 | """
43 | self.download_path = download_path
44 | self.overwrite = overwrite
45 | self.check_md5 = check_md5
46 | self.multiprocessing = multiprocessing
47 | self.server = server
48 | self.datetime_format = datetime_format
49 | self._total = 0
50 |
51 | def get_url(self, file_id):
52 | """Get the download URL for a file ID.
53 |
54 | Parameters
55 | ----------
56 | file_id: str
57 | The file ID to get the download URL for.
58 |
59 | Returns
60 | -------
61 | str
62 | The download URL for the file ID.
63 | """
64 | return f"https://{self.server}/api/access/datafile/{file_id}?gbrecs=true"
65 |
66 | def __call__(self, file_id_mapping, filter_fn=lambda x, y: True):
67 | """Download the files from a file ID mapping.
68 |
69 | Parameters
70 | ----------
71 | file_id_mapping: Mapping[str, Mapping[str, Any]]
72 | A mapping from the path (relative to self.download_path) to save the file
73 | and another mapping containing at least 'id' as a key and 'md5' as a key
74 | (only necessary if self.check_md5 is True).
75 | filter_fn: Callable[[str, str], bool]
76 | A function that takes the path and file ID and returns whether to download
77 | the file.
78 |
79 | Returns
80 | -------
81 | List[str]
82 | A list of the downloaded files.
83 | """
84 | # Get the appropriate map function
85 | if self.multiprocessing not in [0, 1]:
86 | pool_count = (
87 | self.multiprocessing if self.multiprocessing > 0 else os.cpu_count()
88 | )
89 | pool = mp.Pool(pool_count)
90 | map_fn = pool.map
91 | else:
92 | map_fn = map
93 |
94 | # Filter the file ID mapping
95 | filtered_path_dict = self.filter(file_id_mapping, filter_fn=filter_fn)
96 |
97 | # Set total for logging
98 | self._total = len(filtered_path_dict)
99 |
100 | # Download the files
101 | print(
102 | f"Started downloading at "
103 | f"{datetime.datetime.now().strftime(self.datetime_format)}"
104 | )
105 | output = list(map_fn(self.download, enumerate(filtered_path_dict.items())))
106 | print(
107 | f"Finished downloading at "
108 | f"{datetime.datetime.now().strftime(self.datetime_format)}"
109 | )
110 |
111 | # Clean up multiprocessing
112 | if self.multiprocessing not in [0, 1]:
113 | pool.close()
114 | pool.join()
115 | return output
116 |
117 | def filter(self, file_id_mapping, filter_fn=lambda x, y: True):
118 | """Filter a file ID mapping.
119 |
120 | Parameters
121 | ----------
122 | file_id_mapping: Mapping[str, Mapping[str, Any]]
123 | A mapping from the path (relative to self.download_path) to save the file
124 | and another mapping containing at least 'id' as a key and 'md5' as a key
125 | (only necessary if self.check_md5 is True).
126 | filter_fn: Callable[[str, str], bool]
127 | A function that takes the path and file ID and returns whether to download
128 | the file.
129 |
130 | Returns
131 | -------
132 | Mapping[str, str]
133 | The filtered file ID mapping.
134 | """
135 | return {k: v for k, v in file_id_mapping.items() if filter_fn(k, v)}
136 |
137 | def compare_checksum(self, filepath, checksum):  # True if the MD5 digest of 'filepath' matches 'checksum'
138 | with open(filepath, 'rb') as fp:
139 | return hashlib.md5(fp.read()).hexdigest() == checksum
140 |
141 | def download(self, data):
142 | """Download a file.
143 |
144 | Parameters
145 | ----------
146 | data: Tuple[int, Tuple[str, str]]
147 | The index of the file and a tuple containing the relative path of the file,
148 | and a mapping containing at least a key 'id' for the file ID and a key
149 | 'md5' for the MD5 checksum (only necessary if self.check_md5 is True).
150 |
151 | Returns
152 | -------
153 | str
154 | The path to the downloaded file.
155 | """
156 | index, (path, mapping) = data
157 | filepath = os.path.join(self.download_path, path)
158 | print_preamble = (
159 | f"({index+1}/{self._total}) | "
160 | if self.multiprocessing not in [0, 1]
161 | else ""
162 | )
163 |
164 | if os.path.exists(filepath) and not self.overwrite:
165 | if self.check_md5:
166 | checksum_comparison = self.compare_checksum(filepath, mapping['md5'])
167 | if checksum_comparison:
168 | print(
169 | f"{print_preamble}{filepath} already exists and has the same "
170 | f"checksum, skipping (set overwrite=True to overwrite)"
171 | )
172 | return filepath
173 | else:
174 | print(
175 | f"{print_preamble}{filepath} already exists but has a"
176 | f" different checksum, overwriting"
177 | )
178 | else:
179 | print(
180 | f"{print_preamble}{filepath} already exists,"
181 | f"skipping (set overwrite=True to overwrite)"
182 | )
183 | return filepath
184 | os.makedirs(os.path.dirname(filepath), exist_ok=True)
185 | print(f"{print_preamble}Downloading {path} to {filepath}...", end=" ")
186 | urllib.request.urlretrieve(self.get_url(mapping['id']), filename=filepath)
187 |
188 | extra_msg = ""
189 | if self.check_md5:
190 | checksum_comparison = self.compare_checksum(filepath, mapping['md5'])
191 | if not checksum_comparison:
192 | raise ValueError(
193 | f"Checksum of {filepath} does not match the expected checksum."
194 | )
195 | else:
196 | extra_msg = " (checksum matches)"
197 | print(f"Done{extra_msg}")
198 | return filepath
199 |
200 |
201 | class DataverseParser:
202 | """Parse a Dataverse dataset for file IDs."""
203 |
204 | def __init__(self, server):
205 | """Create a new DataverseParser.
206 |
207 | Parameters
208 | ----------
209 | server: str
210 | The hostname of the server to download the files from.
211 | """
212 | self.server = server
213 |
214 | def get_url(self, dataset_id):
215 | """Get the URL to get the dataset information from.
216 |
217 | Parameters
218 | ----------
219 | dataset_id: str
220 | The DOI of the requested dataset.
221 |
222 | Returns
223 | -------
224 | str
225 | The URL to get the dataset information from.
226 | """
227 | return f"https://{self.server}/api/datasets/:persistentId/" \
228 | f"?persistentId={dataset_id}"
229 |
230 | def __call__(self, dataset_id):
231 | """Create a mapping between the relative path to the file and the file ID.
232 |
233 | Parameters
234 | ----------
235 | dataset_id: str
236 | The DOI of the requested dataset.
237 |
238 | Returns
239 | -------
240 | Mapping[str, str]
241 | A mapping between the relative path to the file in the dataset and the
242 | file ID.
243 | """
244 | # Get the dataset information
245 | url = self.get_url(dataset_id)
246 | print(f"Loading data from {url}")
247 | raw_info = urllib.request.urlopen(url).read().decode("utf-8")
248 | info = json.loads(raw_info)
249 | version_info = info["data"]["latestVersion"]
250 | print(
251 | f'Parsing data for version: '
252 | f'{version_info["versionNumber"]}.{version_info["versionMinorNumber"]}.'
253 | )
254 | # Parse the files
255 | file_id_mapping = {}
256 | for file_info in version_info["files"]:
257 | path = os.path.join(
258 | file_info.get("directoryLabel", ""), file_info["dataFile"]["filename"]
259 | )
260 | file_id_mapping[path] = {
261 | 'id': file_info["dataFile"]["id"],
262 | 'md5': file_info["dataFile"]["md5"],
263 | }
264 | return file_id_mapping
265 |
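266 | # Example usage (sketch; server and DOI below are hypothetical placeholders):
267 | #
268 | #     parser = DataverseParser("demo.dataverse.org")
269 | #     file_id_mapping = parser("doi:10.0000/EXAMPLE")
270 | #     downloader = DataverseDownloader("downloads", "demo.dataverse.org")
271 | #     downloaded_files = downloader(file_id_mapping)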
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_linear_backward_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for a linear baseline method."""
2 |
3 |
4 | import sys
5 | import argparse
6 |
7 |
8 | import numpy as np
9 | import glob
10 | import json
11 | import logging
12 | import os
13 | import scipy.stats
14 |
15 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
16 |
17 |
18 |
19 | def time_lag_matrix(input_, tmin, tmax):
20 | """Create a time-lag matrix from a 2D numpy array.
21 |
22 | Parameters
23 | ----------
24 | input_: np.ndarray
25 | 2D numpy array with shape (n_samples, n_channels)
26 | tmin, tmax: int
27 | First and last time lag in samples; the lags cover range(tmin, tmax).
28 |
29 | Returns
30 | -------
31 | np.ndarray
32 | 2D numpy array with shape (n_samples_trimmed, n_channels * (tmax - tmin))
33 | """
34 | # Create a time-lag matrix
35 | numChannels = input_.shape[1]
36 |
37 | final_array = np.zeros((input_.shape[0], numChannels * (tmax - tmin)))
38 |
39 | for index, shift in enumerate(range(tmin, tmax)):
40 | # shift the data by 'shift' samples along the time axis to build the lagged copy
41 | shifted_data = np.roll(input_, -shift, axis=0)
42 | final_array[:, index * numChannels: (index + 1) * numChannels] = shifted_data
43 |
44 | if tmin < 0:
45 | return final_array[np.abs(tmin):-tmax+1, :]
46 | else:
47 | return final_array[:-tmax+1, :]
48 |
49 |
50 | def train_model_cov(cxx, cxy, ridge_param):
51 | return np.linalg.solve(cxx + ridge_param * np.eye(cxx.shape[0]), cxy)
52 |
53 |
54 | def evaluate_model(model, test_eeg, test_env):
55 | pred_env = np.matmul(test_eeg, model)
56 | return scipy.stats.pearsonr(pred_env[:, 0], test_env[:, 0])[0]
57 |
58 | def permutation_test(model, eeg, env, tmin, tmax, numPermutations=100):
59 | pred_env = np.matmul(eeg, model)
60 | corrs = []
61 | for permutation_index in range(numPermutations):
62 | print(f'Permutation {permutation_index+1:03d}\r', end='')
63 | random_shift = np.random.randint(tmax-tmin, env.shape[0] - (tmax-tmin))
64 | temp_env = np.roll(env, random_shift, axis=0)
65 | # temp_eeg = temp_eeg[random_shift:, :]
66 | # temp_pred_eeg = pred_eeg[:-random_shift, :]
67 |
68 | corrs.append(scipy.stats.pearsonr(temp_env[:, 0], pred_env[:, 0])[0])
69 |
70 | print()
71 | return np.array(corrs)
72 |
73 |
74 | def crossval_over_recordings(all_data, tmin, tmax, ridge_param):
75 | # Cross validation loop to determine the optimal ridge parameter
76 | fold_scores = []
77 | for fold in range(len(all_data)):
78 | print(f'fold {fold}...')
79 |
80 | # train_folds
81 | train_eeg_folds = [x[0] for i, x in enumerate(all_data) if i != fold]
82 | train_env_folds = [x[1] for i, x in enumerate(all_data) if i != fold]
83 |
84 | # test_fold
85 | test_eeg_fold = [x[0] for i, x in enumerate(all_data) if i == fold][0]
86 | test_env_fold = [x[1] for i, x in enumerate(all_data) if i == fold][0]
87 |
88 | # create the model
89 | # closed-form solution,
90 | train_eegs = [time_lag_matrix(eeg, tmin, tmax) for eeg in train_eeg_folds]
91 | cxx = np.sum([np.matmul(x.T, x) for x in train_eegs], axis=0)
92 | cxy = np.sum([np.matmul(x.T, y[:-(tmax - tmin) + 1, :]) for x, y in zip(train_eegs, train_env_folds)], axis=0)
93 |
94 | if not isinstance(ridge_param, (int, float)):
95 | ridge_scores = []
96 | for lambd in ridge_param:
97 | model = train_model_cov(cxx, cxy, lambd)
98 | # evaluate the model on the test set
99 | score = evaluate_model(model,time_lag_matrix(test_eeg_fold, tmin, tmax), test_env_fold[:-(tmax - tmin) + 1, :])
100 | ridge_scores.append(score)
101 | fold_scores.append(ridge_scores)
102 | else:
103 | model = train_model_cov(cxx, cxy, ridge_param)
104 | score = evaluate_model(model, time_lag_matrix(test_eeg_fold, tmin, tmax), test_env_fold[:-(tmax - tmin) + 1, :])
105 | fold_scores.append(score)
106 | return fold_scores
107 |
108 | def training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param):
109 | print(f"Training model for subject {subject}")
110 |
111 | train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
112 | train_files = [x for x in train_files if 'audiobook_15_' not in x]
113 | train_generator = RegressionDataGenerator(train_files, high_pass_freq=highpass, low_pass_freq=lowpass)
114 | all_data = [x for x in train_generator]
115 |
116 | # Leave-one-out cross validation based on number of recordings
117 | # Done to determine the optimal ridge parameter
118 | numFolds = len(all_data)
119 |
120 | fold_scores = crossval_over_recordings(all_data, tmin, tmax, ridge_param)
121 |
122 | if not isinstance(ridge_param, (int, float)):
123 | fold_scores = np.array(fold_scores)
124 | # Take the average across folds to obtain 1 correlation value
125 | # per lambda value
126 | fold_scores = fold_scores.mean(axis=0)
127 | best_ridge = ridge_param[np.argmax(fold_scores)]
128 | print(f"Best lambda: {best_ridge}")
129 | else:
130 | best_ridge = ridge_param
131 |
132 |
133 | # Actual training of the model on all training folds
134 | train_eegs = [time_lag_matrix(x[0].numpy(), tmin, tmax) for x in all_data]
135 | train_envs = [x[1].numpy() for x in all_data]
136 | cxx = np.sum([np.matmul(x.T, x) for x in train_eegs], axis=0)
137 | cxy = np.sum([np.matmul(x.T, y[:-(tmax-tmin)+1, :]) for x, y in zip(train_eegs, train_envs)], axis=0)
138 | model = train_model_cov(cxx, cxy, best_ridge)
139 |
140 | # # evaluate the model on the test set
141 | test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
142 | test_files = [x for x in test_files if 'audiobook_15_' not in x]
143 | test_generator = RegressionDataGenerator(test_files, high_pass_freq=highpass, low_pass_freq=lowpass, return_filenames=True)
144 |
145 | # calculate pearson correlation, per test segment
146 | # and average over all test segments
147 | test_info = {'subject': subject, 'stim_filename':[], 'score': [], 'null_distr': [], 'ridge_param': best_ridge, 'model_weights': model.tolist(), 'highpass':highpass, 'lowpass':lowpass, 'numFolds':numFolds}
148 | print(f"Testing model on test data... ({numFolds})")
149 | for test_filenames, test_seg in test_generator:
150 | test_eeg = test_seg[0]
151 | test_env = test_seg[1]
152 | stim_filename = os.path.basename(test_filenames[1]).split("_-_")[-2]
153 |
154 | test_eeg = time_lag_matrix(test_eeg, tmin, tmax)
155 | # shorten the envelope to match the length of the time-lagged EEG
156 | test_env = test_env[:-(tmax-tmin)+1, :]
157 | # predict
158 | pearson_scores = evaluate_model(model, test_eeg, test_env)
159 |
160 | test_info['stim_filename'].append(stim_filename)
161 | test_info['score'].append(pearson_scores)
162 | # null distribution
163 | null_distr = permutation_test(model, test_eeg, test_env, tmin, tmax).tolist()
164 | test_info['null_distr'].append(null_distr)
165 | test_info['mean_score'] = np.mean(test_info['score'])
166 | test_info['95_percentile'] = np.percentile(test_info['null_distr'], 95)
167 |
168 | return test_info
169 |
170 |
171 |
172 | if __name__ == "__main__":
173 | # Parameters
174 |
175 | # frequency band chosen for the experiment
176 | # delta (0.5 -4 )
177 | # theta (4 - 8)
178 | # alpha (8 - 14)
179 | # beta (14 - 30)
180 | # broadband (0.5 - 32)
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument('--highpass', type=float, default=None)
183 | parser.add_argument('--lowpass', type=float, default=None)
184 |
185 | args = parser.parse_args()
186 | highpass = args.highpass
187 | lowpass = args.lowpass
188 |
189 | for highpass, lowpass in ([(highpass, lowpass)] if highpass is not None or lowpass is not None else [(None, 4), (4, 8), (8, 14), (14, 30), (None, None)]):  # use the band from the CLI if given, otherwise sweep all bands
190 | numChannels = 64
191 | tmin = -np.round(0.1*64).astype(int) # -100 ms
192 | tmax = np.round(0.4*64).astype(int) # 400 ms
193 | ridge_param = [10**x for x in range(-6, 7, 2)]
194 | overwrite = False
195 |
196 | results_filename = 'eval_filter_{subject}_{tmin}_{tmax}_{highpass}_{lowpass}.json'
197 |
198 | # Get the path to the config file
199 | experiments_folder = os.path.dirname(__file__)
200 | main_folder = os.path.dirname(os.path.dirname(experiments_folder))
201 | config_path = os.path.join(main_folder, 'config.json')
202 |
203 | # Load the config
204 | with open(config_path) as fp:
205 | config = json.load(fp)
206 |
207 | # Provide the path of the dataset
208 | # which is split already to train, val, test
209 |
210 | data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
211 | features = ["envelope", "eeg"]
212 |
213 | # Create a directory to store (intermediate) results
214 | results_folder = os.path.join(experiments_folder, "results_linear_backward")
215 | os.makedirs(results_folder, exist_ok=True)
216 |
217 |
218 | # get all the subjects
219 | all_files = glob.glob(os.path.join(data_folder, "train_-_*"))
220 | subjects = list(set([os.path.basename(x).split("_-_")[1] for x in all_files]))
221 |
222 | evaluation_all_subs = {}
223 | chance_level_all_subs = {}
224 |
225 | # train one model per subject
226 | for subject in subjects:
227 | save_path = os.path.join(results_folder, results_filename.format(subject=subject, tmin=tmin, tmax=tmax, highpass=highpass, lowpass=lowpass))
228 | if not os.path.exists(save_path) or overwrite:
229 | result = training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param)
230 |
231 | # save the results
232 | with open(save_path, 'w') as fp:
233 | json.dump(result, fp)
234 | else:
235 | print(f"Results for {subject} already exist, skipping...")
236 |
237 |
238 |
239 |
240 |
241 |
242 |
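243 | # Example invocation (sketch; requires the repository root on PYTHONPATH so that
244 | # the technical_validation package can be imported):
245 | #
246 | #     python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 0.5 --lowpass 4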
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_linear_forward_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for a linear baseline method."""
2 |
3 |
4 | import sys
5 | import argparse
6 |
7 |
8 | import numpy as np
9 | import glob
10 | import json
11 | import logging
12 | import os
13 | import scipy.stats
14 |
15 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
16 |
17 |
18 |
19 | def time_lag_matrix(input_, tmin, tmax):
20 | """Create a time-lag matrix from a 2D numpy array.
21 |
22 | Parameters
23 | ----------
24 | input_: np.ndarray
25 | 2D numpy array with shape (n_samples, n_channels)
26 | tmin, tmax: int
27 | First and last time lag in samples; the lags cover range(tmin, tmax).
28 |
29 | Returns
30 | -------
31 | np.ndarray
32 | 2D numpy array with shape (n_samples_trimmed, n_channels * (tmax - tmin))
33 | """
34 | # Create a time-lag matrix
35 | numChannels = input_.shape[1]
36 |
37 | final_array = np.zeros((input_.shape[0], numChannels * (tmax - tmin)))
38 |
39 | for index, shift in enumerate(range(tmin, tmax)):
40 | # roll the array to the right
41 | shifted_data = np.roll(input_, -shift, axis=0)
42 | final_array[:, index * numChannels: (index + 1) * numChannels] = shifted_data
43 |
44 | if tmin < 0:
45 | return final_array[np.abs(tmin):-tmax+1, :]
46 | else:
47 | return final_array[:-tmax+1, :]
48 |
49 |
50 | def train_model_cov(cxx, cxy, ridge_param):
51 | return np.linalg.solve(cxx + ridge_param * np.eye(cxx.shape[0]), cxy)
52 |
53 | def evaluate_model(model, test_env, test_eeg):
54 | pred_eeg = np.matmul(test_env, model)
55 | channel_scores = []
56 | for channel in range(test_eeg.shape[1]):
57 | score = scipy.stats.pearsonr(pred_eeg[:, channel], test_eeg[:, channel])[0]
58 | channel_scores.append(score.tolist())
59 | return channel_scores
60 |
61 | def permutation_test(model, eeg, env, tmin, tmax, numPermutations=100):
62 | pred_eeg = np.matmul(env, model)
63 | corrs = []
64 | for permutation_index in range(numPermutations):
65 | print(f'Permutation {permutation_index+1:03d}\r', end='')
66 | random_shift = np.random.randint(tmax-tmin, eeg.shape[0] - (tmax-tmin))
67 | temp_eeg = np.roll(eeg, random_shift, axis=0)
68 | # temp_eeg = temp_eeg[random_shift:, :]
69 | # temp_pred_eeg = pred_eeg[:-random_shift, :]
70 | temp_pred_eeg = pred_eeg
71 | channel_corrs = []
72 | for channel in range(eeg.shape[1]):
73 | channel_corrs.append(scipy.stats.pearsonr(temp_eeg[:, channel], temp_pred_eeg[:, channel])[0])
74 | corrs.append(channel_corrs)
75 | print()
76 | return np.array(corrs)
77 |
78 | def crossval_over_recordings(all_data, tmin, tmax, ridge_param):
79 | # Cross validation loop to determine the optimal ridge parameter
80 | fold_scores = []
81 | for fold in range(len(all_data)):
82 | print(f'fold {fold}...')
83 |
84 | # train_folds
85 | train_eeg_folds = [x[0] for i, x in enumerate(all_data) if i != fold]
86 | train_env_folds = [x[1] for i, x in enumerate(all_data) if i != fold]
87 |
88 | # test_fold
89 | test_eeg_fold = [x[0] for i, x in enumerate(all_data) if i == fold][0]
90 | test_env_fold = [x[1] for i, x in enumerate(all_data) if i == fold][0]
91 |
92 | # create the model
93 | # closed-form solution,
94 | train_envs = [time_lag_matrix(env, tmin, tmax) for env in train_env_folds]
95 | cxx = np.sum([np.matmul(x.T, x) for x in train_envs], axis=0)
96 | cxy = np.sum([np.matmul(x.T, y[:-(tmax - tmin) + 1, :]) for x, y in
97 | zip(train_envs, train_eeg_folds)], axis=0)
98 |
99 | if not isinstance(ridge_param, (int, float)):
100 | ridge_scores = []
101 | for lambd in ridge_param:
102 | model = train_model_cov(cxx, cxy, lambd)
103 | # evaluate the model on the test set
104 | score = evaluate_model(model,
105 | time_lag_matrix(test_env_fold, tmin, tmax),
106 | test_eeg_fold[:-(tmax - tmin) + 1, :])
107 | ridge_scores.append(score)
108 | fold_scores.append(ridge_scores)
109 | else:
110 | model = train_model_cov(cxx, cxy, ridge_param)
111 | score = evaluate_model(model, time_lag_matrix(test_env_fold, tmin, tmax),
112 | test_eeg_fold[:-(tmax - tmin) + 1, :])
113 | fold_scores.append(score)
114 | return fold_scores
115 |
116 | def training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param):
117 | print(f"Training model for subject {subject}")
118 |
119 | train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
120 | train_files = [x for x in train_files if 'audiobook_15_' not in x]
121 | train_generator = RegressionDataGenerator(train_files, high_pass_freq=highpass, low_pass_freq=lowpass)
122 | all_data = [x for x in train_generator]
123 |
124 | # Leave-one-out cross validation based on number of recordings
125 | # Done to determine the optimal ridge parameter
126 | numFolds = len(all_data)
127 |
128 | fold_scores = crossval_over_recordings(all_data, tmin, tmax, ridge_param)
129 |
130 | if not isinstance(ridge_param, (int, float)):
131 | fold_scores = np.array(fold_scores)
132 | # Take the average across channels and folds to obtain 1 correlation value
133 | # per lambda value
134 | fold_scores = fold_scores.mean(axis=2).mean(axis=0)
135 | best_ridge = ridge_param[np.argmax(fold_scores)]
136 | print(f"Best lambda: {best_ridge}")
137 | else:
138 | best_ridge = ridge_param
139 |
140 |
141 | # Actual training of the model on all training folds
142 | train_eegs = [x[0] for x in all_data]
143 | train_envs = [time_lag_matrix(x[1], tmin, tmax) for x in all_data]
144 | cxx = np.sum([np.matmul(x.T, x) for x in train_envs], axis=0)
145 | cxy = np.sum([np.matmul(x.T, y[:-(tmax-tmin)+1, :]) for x, y in zip(train_envs, train_eegs)], axis=0)
146 | model = train_model_cov(cxx, cxy, best_ridge)
147 |
148 | # # evaluate the model on the test set
149 | test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
150 | test_files = [x for x in test_files if 'audiobook_15_' not in x]
151 | test_generator = RegressionDataGenerator(test_files, high_pass_freq=highpass, low_pass_freq=lowpass, return_filenames=True)
152 |
153 | # calculate pearson correlation, per test segment
154 | # and average over all test segments
155 | test_info = {'subject': subject, 'stim_filename':[], 'score': [], 'null_distr': [], 'ridge_param': best_ridge, 'model_weights': model.tolist(), 'highpass':highpass, 'lowpass':lowpass, 'numFolds':numFolds}
156 | print(f"Testing model on test data... ({numFolds})")
157 | for test_filenames, test_seg in test_generator:
158 | test_eeg = test_seg[0]
159 | test_env = test_seg[1]
160 | stim_filename = os.path.basename(test_filenames[1]).split("_-_")[-2]
161 |
162 | test_env = time_lag_matrix(test_env, tmin, tmax)
163 | # shorten the EEG to match the length of the time-lagged envelope
164 | test_eeg = test_eeg[:-(tmax-tmin)+1, :]
165 | # predict
166 | pearson_scores = evaluate_model(model, test_env, test_eeg)
167 |
168 | test_info['stim_filename'].append(stim_filename)
169 | test_info['score'].append(pearson_scores)
170 | # null distribution
171 | null_distr = permutation_test(model, test_eeg, test_env, tmin, tmax).tolist()
172 | test_info['null_distr'].append(null_distr)
173 | test_info['mean_score_per_channel'] = np.mean(test_info['score'], axis=0).tolist()
174 | test_info['mean_score'] = np.mean(test_info['score'])
175 | null_distr = np.reshape(test_info['null_distr'], (-1, 64))
176 | test_info['95_percentile_per_channel'] = np.percentile(null_distr, 95, axis=0).tolist()
177 | test_info['95_percentile'] = np.percentile(test_info['null_distr'], 95)
178 |
179 | return test_info
180 |
181 |
182 |
183 | if __name__ == "__main__":
184 | # Parameters
185 |
186 | # frequency band chosen for the experiment
187 | # delta (0.5 -4 )
188 | # theta (4 - 8)
189 | # alpha (8 - 14)
190 | # beta (14 - 30)
191 | # broadband (0.5 - 32)
192 | parser = argparse.ArgumentParser()
193 | parser.add_argument('--highpass', type=float, default=None)
194 | parser.add_argument('--lowpass', type=float, default=None)
195 |
196 | args = parser.parse_args()
197 | highpass = args.highpass
198 | lowpass = args.lowpass
199 |
200 | for highpass, lowpass in ([(highpass, lowpass)] if highpass is not None or lowpass is not None else [(None, 4), (4, 8), (8, 14), (14, 30), (None, None)]):  # use the band from the CLI if given, otherwise sweep all bands
201 |
202 |
203 | numChannels = 64
204 | tmin = -np.round(0.1*64).astype(int) # -100 ms
205 | tmax = np.round(0.4*64).astype(int) # 400 ms
206 | ridge_param = [10**x for x in range(-6, 7, 2)]
207 | overwrite = False
208 |
209 | results_filename = 'eval_filter_{subject}_{tmin}_{tmax}_{highpass}_{lowpass}.json'
210 |
211 | # Get the path to the config file
212 | experiments_folder = os.path.dirname(__file__)
213 | main_folder = os.path.dirname(os.path.dirname(experiments_folder))
214 | config_path = os.path.join(main_folder, 'config.json')
215 |
216 | # Load the config
217 | with open(config_path) as fp:
218 | config = json.load(fp)
219 |
220 | # Provide the path of the dataset
221 | # which is split already to train, val, test
222 |
223 | data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
224 | features = ["envelope", "eeg"]
225 |
226 | # Create a directory to store (intermediate) results
227 | results_folder = os.path.join(experiments_folder, "results_linear_forward")
228 | os.makedirs(results_folder, exist_ok=True)
229 |
230 |
231 | # get all the subjects
232 | all_files = glob.glob(os.path.join(data_folder, "train_-_*"))
233 | subjects = list(set([os.path.basename(x).split("_-_")[1] for x in all_files]))
234 |
235 | evaluation_all_subs = {}
236 | chance_level_all_subs = {}
237 |
238 | # train one model per subject
239 | for subject in subjects:
240 | save_path = os.path.join(results_folder, results_filename.format(subject=subject, tmin=tmin, tmax=tmax, highpass=highpass, lowpass=lowpass))
241 | if not os.path.exists(save_path) or overwrite:
242 | result = training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param)
243 |
244 | # save the results
245 | with open(save_path, 'w') as fp:
246 | json.dump(result, fp)
247 | else:
248 | print(f"Results for {subject} already exist, skipping...")
249 |
250 |
251 |
252 |
253 |
254 |
255 |
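256 | # Example invocation (sketch; requires the repository root on PYTHONPATH so that
257 | # the technical_validation package can be imported):
258 | #
259 | #     python3 technical_validation/experiments/regression_linear_forward_model.py --highpass 0.5 --lowpass 4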
--------------------------------------------------------------------------------
/technical_validation/util/plot_results.py:
--------------------------------------------------------------------------------
1 | # import seaborn as sns
2 | import glob
3 | import json
4 | import os
5 |
6 | import matplotlib.pyplot as plt
7 | import mne
8 | import numpy as np
9 | import pandas as pd
10 | import scipy.stats
11 | import seaborn as sns
12 |
13 | # generate plots from all the different results and save them in the figures folder
14 |
15 | # load the results
16 | base_results_folder = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'experiments')
17 | os.makedirs(os.path.join(base_results_folder, 'figures'), exist_ok=True)
18 | plot_dilation = True
19 | plot_linear_backward = True
20 | plot_linear_forward = True
21 | plot_vlaai = True
22 |
23 | freq_bands = {
24 | 'Delta [0.5-4]': (0.5, 4.0),
25 | 'Theta [4-8]': (4, 8.0),
26 | 'Alpha [8-14]': (8, 14.0),
27 | 'Beta [14-30]': (14, 30.0),
28 | 'Broadband [0.5-32]': (0.5, 32.0),
29 | }
30 |
31 | if plot_dilation:
32 | # dilation model, match mismatch results
33 | # plot boxplot of the results per window length
34 | # load evaluation results for all window lengths
35 |
36 | files = glob.glob(os.path.join(base_results_folder, "results_dilated_convolutional_model/eval_*.json"))
37 | # sort the files
38 | files.sort()
39 |
40 | # create dict to save all results per sub
41 | results = []
42 | windows = []
43 | for f in files:
44 | # load the results
45 | with open(f, "rb") as ff:
46 | res = json.load(ff)
47 | #loop over res and get accuracy in a list
48 | acc = []
49 | for sub, sub_res in res.items():
50 |
51 | if 'acc' in sub_res:
52 | acc.append(sub_res['acc']*100)
53 |
54 | results.append(acc)
55 |
56 | # get the window length
57 | windows.append(int(f.split("_")[-1].split(".")[0]))
58 |
59 | # sort windows and results according to windows
60 | windows, results = zip(*sorted(zip(windows, results)))
61 | # convert windows to seconds
62 | windows = ['%d' %int(w/64) for w in windows]
63 |
64 | #boxplot of the results
65 | plt.boxplot(results, labels=windows)
66 | plt.xlabel("Window length (s)")
67 | plt.ylabel("Accuracy (%)")
68 | # plt.title("Accuracy of dilation model, per window length")
69 | plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_dilated_conv.pdf"))
70 |
71 | if plot_linear_backward:
72 | ## linear backward regression plots
73 | # load the results
74 | results_files_glob = os.path.join(base_results_folder, "results_linear_backward", "eval*.json")
75 | results_files = glob.glob(results_files_glob)
76 | all_results = []
77 |
78 | for result_file in results_files:
79 | with open(result_file, "r") as f:
80 | data = json.load(f)
81 | for index, stim in enumerate(data['stim_filename']):
82 | percentile = np.percentile(data['null_distr'][index], 95)
83 | score = data['score'][index]
84 | highpass = data['highpass'] if data['highpass'] is not None else 0.5
85 | lowpass = data['lowpass'] if data['lowpass'] is not None else 32.0
86 |
87 | all_results.append([data['subject'], stim, score, highpass, lowpass, percentile, score > percentile, data['null_distr'][index]])
88 |
89 | df = pd.DataFrame(all_results, columns=['subject', 'stim', 'score', 'highpass', 'lowpass', 'percentile', 'significant', 'null_distr'])
90 |
91 | print('Confirming that we found neural tracking for each subject')
92 | nb_subjects = len(pd.unique(df['subject']))
93 | nb_significant = df.groupby('subject').agg({'significant': 'any'}).sum()
94 | print('Found {} subjects, {} of which had at least one significant result'.format(nb_subjects, nb_significant))
95 |
96 |
97 | subject_stories_group = df.groupby(['subject', 'stim']).agg({'significant': 'any'})
98 | nb_recordings = len(subject_stories_group)
99 | nb_significant_recordings = subject_stories_group.sum()
100 | print("Found {} recordings, {} of which were significant".format(nb_recordings, nb_significant_recordings))
101 | non_significant_recordings_series = df.groupby(['subject', 'stim']).agg({'significant': 'any'})['significant'] == False
102 | non_significant_recordings = non_significant_recordings_series[non_significant_recordings_series].index
103 |
104 | # Table of non-significant results
105 | print("Non-significant results:")
106 | for subject, stimulus in non_significant_recordings:
107 | print("{} & {} \\\\".format(subject, stimulus))
108 |
109 | # plot the results
110 | ## General frequency plot
111 | values_for_boxplot = []
112 | significance_level = []
113 | names_for_boxplot = []
114 | for band_name, (highpass, lowpass) in freq_bands.items():
115 |
116 | selected_df = df[(df['highpass'] == highpass) & (df['lowpass'] == lowpass)]
117 | values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
118 | significance_level.append(np.percentile(selected_df['null_distr'].tolist(), 95))
119 |
120 | plt.figure(figsize=(8, 5))
121 | # plt.boxplot(values_for_boxplot, labels=freq_bands.keys())
122 | temp_df = pd.DataFrame(values_for_boxplot, index=freq_bands.keys(), columns=range(85)).T
123 | sns.violinplot(data=temp_df)
124 |
125 | xs = [np.linspace(x - 0.5, x + 0.5, 200) for x in range(5)]
126 | ys = [[significance_level[i]]*200 for i in range(5)]
127 |
128 | plt.plot(np.reshape(xs, (-1,)), np.reshape(ys, (-1,)), '--', color='black')
129 | plt.ylabel('Pearson correlation')
130 | plt.xlabel('Frequency band')
131 | #plt.grid(True)
132 | plt.xlim(-0.5, len(values_for_boxplot)-0.5)
133 | plt.title('Linear decoder performance across frequency bands')
134 | # plt.show()
135 | plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_frequency.pdf"))
136 | #plt.close()
137 |
138 | ## plot the results per stimulus
139 | values_for_boxplot = []
140 | names_for_boxplot = []
141 | def sort_key_fn(x):
142 | split = x.split('_')
143 | number = ord(split[0][0])*1e6+ int(split[1])*1e3
144 | # artifact audiobook
145 | if len(split) > 2 and split[2].isdigit():
146 | number += int(split[2])
147 | return number
148 |
149 | fig, (ax0, ax1) = plt.subplots(1,2, width_ratios=[0.2, 0.8], figsize=(15,6), sharey=True)
150 |
151 | for stimulus in sorted(pd.unique(df['stim']), key=sort_key_fn):
152 | selected_df = df[(df['stim'] == stimulus) &(df['highpass'] == 0.5) & (df['lowpass'] == 4.0)]
153 | values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
154 | names_for_boxplot.append(stimulus + ' ({})'.format(len(selected_df)))
155 | # plt.figure(figsize=(11,6))
156 | ax1.boxplot(values_for_boxplot, labels=names_for_boxplot)
157 | #ax1.set_ylabel('Pearson correlation')
158 | # ax1.set_xlabel('Stimulus (Number of recordings)')
159 | plt.xticks(rotation=90)
160 |
161 | # plt.grid(True)
162 | ax1.set_title('Across stimuli')
163 | # plt.tight_layout()
164 | # plt.show()
165 | # plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimuli.pdf"))
166 | # plt.close()
167 |
168 |
169 | ## plot the results per stimulus type
170 | values_for_boxplot = []
171 | names_for_boxplot = []
172 | stim_type_selectors = {
173 | 'Audiobook': df['stim'].str.startswith('audiobook_'),
174 | 'Podcast': df['stim'].str.startswith('podcast_'),
175 | }
176 | test_data = []
177 | for stimulus_type, stimulus_selector in stim_type_selectors.items():
178 | selected_df = df[(stimulus_selector) &(df['highpass'] == 0.5) & (df['lowpass'] == 4.0) & (~(df['stim'].str.endswith('artefact'))) & (~(df['stim'].str.endswith('shifted'))) & (~(df['stim'].str.endswith('audiobook_1_1')))& (~(df['stim'].str.endswith('audiobook_1_2')))]
179 | values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
180 | names_for_boxplot.append(stimulus_type + ' ({})'.format(len(selected_df)))
181 | test_data += [selected_df['score'].tolist()]
182 |
183 | print("MannWhitneyU test: {}, medians: {}, {}".format(scipy.stats.mannwhitneyu(test_data[0], test_data[1]),np.median(test_data[0]), np.median(test_data[1])))
184 | # plt.figure()
185 | ax0.boxplot(values_for_boxplot, labels=names_for_boxplot)
186 | ax0.set_ylabel('Pearson correlation')
187 | # ax0.set_xlabel('\nStimulus type (Number of recordings)')
188 | # plt.xticks()
189 |
190 | #plt.grid(True)
191 | ax0.set_title('Across stimuli types')
192 | fig.suptitle('Linear decoder performance')
193 | fig.supxlabel('Stimulus [type] (Number of recordings)')
194 |
195 | plt.tight_layout()
196 |
197 | # plt.show()
198 | # plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimulus_type.pdf"))
199 | plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimulus_combined.pdf"))
200 |
201 | # plt.close()
202 |
203 |
204 | if plot_linear_forward:
205 | results_files_glob = os.path.join(base_results_folder, "results_linear_forward",
206 | "eval*.json")
207 | results_files = glob.glob(results_files_glob)
208 | all_results = []
209 |
210 | for result_file in results_files:
211 | with open(result_file, "r") as f:
212 | data = json.load(f)
213 | for index, stim in enumerate(data['stim_filename']):
214 | percentile = np.percentile(data['null_distr'][index], 95)
215 | score = data['score'][index]
216 | null_distr = data['null_distr'][index]
217 | highpass = data['highpass'] if data['highpass'] is not None else 0.5
218 | lowpass = data['lowpass'] if data['lowpass'] is not None else 32.0
219 |
220 | all_results.append([data['subject'], stim, score, null_distr, highpass, lowpass, percentile, score > percentile])
221 |
222 | df = pd.DataFrame(all_results, columns=['subject', 'stim', 'scores_per_channel', 'null_distr', 'highpass', 'lowpass', 'percentile', 'significant'])
223 |
224 | montage = mne.channels.make_standard_montage('biosemi64')
225 | sfreq = 64
226 | info = mne.create_info(ch_names=montage.ch_names, sfreq=sfreq, ch_types='eeg').set_montage(montage)
227 |
228 | fig, axes = plt.subplots(3, 5, figsize=(16, 12))
229 | temp_df = df.copy()
230 | temp_df['filterband'] = temp_df['highpass'].astype(str) + '-' + temp_df['lowpass'].astype(str)
231 | scores = temp_df.groupby(['filterband']).agg({'scores_per_channel': lambda x: list(x)})
232 | all_scores = np.mean(scores['scores_per_channel'].tolist(), axis=1)
233 | max_coef = np.max(all_scores)
234 | min_coef = np.min(all_scores)
235 |
236 | stim_type_selectors = {
237 | 'All Stimuli': df['stim'].str.startswith(''),
238 | 'Audiobook': df['stim'].str.startswith('audiobook_'),
239 | 'Podcast': df['stim'].str.startswith('podcast_'),
240 |
241 | }
242 | for index, (stim_type, selector) in enumerate(stim_type_selectors.items()):
243 |
244 | axes[index][0].set_ylabel(stim_type)
245 |
246 | for index2, (band_name, (highpass, lowpass)) in enumerate(freq_bands.items()):
247 | ax = axes[index][index2]
248 | selected_df = df[(df['highpass'] == highpass) & (df['lowpass'] == lowpass) & (selector)]
249 | scores_per_subject = selected_df.groupby('subject').agg({'scores_per_channel': lambda x: np.mean(np.stack(x,axis=1), axis=1)})
250 | mean_scores = np.mean(scores_per_subject.to_numpy(), axis=0).tolist()[0]
251 | # percentile = np.percentile(np.concatenate(selected_df['null_distr'].tolist(), axis=0), 95, axis=0)
252 | # plot the topoplot
253 | im , cn = mne.viz.plot_topomap(mean_scores, pos=info, axes=ax ,show=False, cmap='Reds', vlim=(min_coef,max_coef))
254 | mne.viz.tight_layout()
255 | ax.set_title(f"{band_name} Hz")
256 |
257 | # cbar_ax = fig.add_axes([0.95, 0.15, 0.05, 0.7])
258 |
259 |
260 |
261 | #plt.colorbar(im, cax=axes[5], label='Pearson correlation') #cax=cbar_ax,
262 | fig.suptitle('Forward model performance across stimuli types and frequency bands',y=0.94, fontsize=18)
263 | fig.tight_layout()
264 |
265 | fig.subplots_adjust(right=0.90)
266 | cbar_ax = fig.add_axes([0.92, 0.11, 0.02, 0.77])
267 | fig.colorbar(im, cax=cbar_ax, label='Pearson correlation')
268 | plt.savefig(os.path.join(base_results_folder, 'figures', f"topoplot_linear_forward.pdf"))
269 | #plt.show()
270 | plt.close(fig)
271 |
272 |
273 |
274 |
275 |
276 |
277 | if plot_vlaai:
278 |
279 | ## vlaai results
280 | # load the results
281 | results_files = glob.glob(os.path.join(base_results_folder, "results_vlaai/eval*.json"))
282 |
283 | results = {}
284 | for f in results_files:
285 |
286 | with open(f, "rb") as ff:
287 | res = json.load(ff)
288 | #loop over res and get accuracy in a list
289 | acc = []
290 | subs = []
291 | for sub, sub_res in res.items():
292 | if 'pearson_metric' in sub_res:
293 | acc.append(sub_res['pearson_metric'])
294 | subs.append(sub)
295 |
296 | results = acc
297 |
298 | # plot the results
299 | plt.figure()
300 | plt.boxplot(results, labels=['vlaai'])
301 | plt.xlabel("Model")
302 | plt.ylabel("Correlation")
303 | # plt.title("Correlation, per subject")
304 | plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_vlaai.pdf"))
305 |
306 |
307 |
--------------------------------------------------------------------------------
/technical_validation/util/dataset_generator.py:
--------------------------------------------------------------------------------
1 | """Code for the dataset_generator for task1."""
2 | import itertools
3 | import os
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | import scipy.signal
8 |
9 | @tf.function
10 | def default_batch_equalizer_fn(*args):
11 | """Batch equalizer.
12 | Prepares the inputs for a model to be trained in
13 | match-mismatch task. It makes sure that match_env
14 | and mismatch_env are equally presented as a first
15 | envelope in match-mismatch task.
16 |
17 | Parameters
18 | ----------
19 | args : Sequence[tf.Tensor]
20 | List of tensors representing feature data
21 |
22 | Returns
23 | -------
24 | Tuple[Tuple[tf.Tensor], tf.Tensor]
25 | Tuple of the EEG/speech features serving as the input to the model and
26 | the labels for the match/mismatch task
27 |
28 | Notes
29 | -----
30 | This function will also double the batch size. E.g. if the batch size of
31 | the elements in each of the args was 32, the output features will have
32 | a batch size of 64.
33 | """
34 | eeg = args[0]
35 | new_eeg = tf.concat([eeg, eeg], axis=0)
36 | all_features = [new_eeg]
37 | for match, mismatch in zip(args[1::2], args[2::2]):
38 | stimulus_feature1 = tf.concat([match, mismatch], axis=0)
39 | stimulus_feature2 = tf.concat([mismatch, match], axis=0)
40 | all_features += [stimulus_feature1, stimulus_feature2]
41 | labels = tf.concat(
42 | [
43 | tf.tile(tf.constant([[0]]), [tf.shape(eeg)[0], 1]),
44 | tf.tile(tf.constant([[1]]), [tf.shape(eeg)[0], 1]),
45 | ],
46 | axis=0,
47 | )
48 |
49 | # print(new_eeg.shape, env1.shape, env2.shape, labels.shape)
50 | return tuple(all_features), labels
51 |
52 |
53 | def create_tf_dataset(
54 | data_generator,
55 | window_length,
56 | batch_equalizer_fn=None,
57 | hop_length=64,
58 | batch_size=64,
59 | data_types=(tf.float32, tf.float32, tf.float32),
60 | feature_dims=(64, 1, 1),
61 | eeg_first=True  # if True, EEG is the first feature; otherwise the stimulus comes first (forward models)
62 | ):
63 | """Creates a tf.data.Dataset.
64 |
65 | This will be used to create a dataset generator that will
66 | pass windowed data to a model in both tasks.
67 |
68 | Parameters
69 | ---------
70 | data_generator: DataGenerator
71 | A data generator.
72 | window_length: int
73 | Length of the decision window in samples.
74 | batch_equalizer_fn: Callable
75 | Function that will be applied on the data after batching (using
76 | the `map` method from tf.data.Dataset). In the match/mismatch task,
77 | this function creates the imposter segments and labels.
78 | hop_length: int
79 | Hop length between two consecutive decision windows.
80 | batch_size: Optional[int]
81 | If not None, specifies the batch size. In the match/mismatch task,
82 | this amount will be doubled by the default_batch_equalizer_fn
83 | data_types: Union[Sequence[tf.dtype], tf.dtype]
84 | The data types that the individual features of data_generator should
85 | be cast to. If you only specify a single datatype, it will be chosen
86 | for all EEG/speech features.
87 |
88 | Returns
89 | -------
90 | tf.data.Dataset
91 | A Dataset object that generates data to train/evaluate models
92 | efficiently
93 | """
94 | # create tf dataset from generator
95 | dataset = tf.data.Dataset.from_generator(
96 | data_generator,
97 | output_signature=tuple(
98 | tf.TensorSpec(shape=(None, x), dtype=data_types[index])
99 | for index, x in enumerate(feature_dims)
100 | ),
101 | )
102 | # if forward model, swap the order of the features
103 | if not eeg_first:
104 | dataset = dataset.map(lambda eeg, speech: (speech, eeg))
105 |
106 | # window dataset
107 | dataset = dataset.map(
108 | lambda *args: [
109 | tf.signal.frame(arg, window_length, hop_length, axis=0)
110 | for arg in args
111 | ]
112 | )
113 |
114 | # batch data
115 | dataset = dataset.interleave(
116 | lambda *args: tf.data.Dataset.from_tensor_slices(args),
117 | cycle_length=4,
118 | block_length=16,
119 | )
120 | if batch_size is not None:
121 | dataset = dataset.batch(batch_size, drop_remainder=True)
122 |
123 | if batch_equalizer_fn is not None:
124 | # Create the labels and make sure classes are balanced
125 | dataset = dataset.map(batch_equalizer_fn)
126 |
127 | return dataset
128 |
129 |
130 |
131 | class MatchMismatchDataGenerator:
132 | """Generate data for the Match/Mismatch task."""
133 |
134 | def __init__(
135 | self,
136 | files,
137 | window_length,
138 | spacing
139 | ):
140 | """Initialize the DataGenerator.
141 |
142 | Parameters
143 | ----------
144 | files: Sequence[Union[str, pathlib.Path]]
145 | Files to load.
146 | window_length: int
147 | Length of the decision window.
148 | spacing: int
149 | Spacing between matched and mismatched samples
150 | """
151 | self.window_length = window_length
152 | self.files = self.group_recordings(files)
153 | self.spacing = spacing
154 |
155 | def group_recordings(self, files):
156 | """Group recordings and corresponding stimuli.
157 |
158 | Parameters
159 | ----------
160 | files : Sequence[Union[str, pathlib.Path]]
161 | List of filepaths to preprocessed and split EEG and speech features
162 |
163 | Returns
164 | -------
165 | list
166 | Files grouped by the self.group_key_fn and subsequently sorted
167 | by the self.feature_sort_fn.
168 | """
169 | new_files = []
170 | grouped = itertools.groupby(sorted(files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3]))
171 | for recording_name, feature_paths in grouped:
172 | new_files += [sorted(feature_paths, key=lambda x: "0" if os.path.basename(x).split("_-_")[-1].split(".")[0] == "eeg" else x)]
173 | return new_files
174 |
175 | def __len__(self):
176 | return len(self.files)
177 |
178 | def __getitem__(self, recording_index):
179 | """Get data for a certain recording.
180 |
181 | Parameters
182 | ----------
183 | recording_index: int
184 | Index of the recording in this dataset
185 |
186 | Returns
187 | -------
188 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]]
189 | The features corresponding to the recording_index recording
190 | """
191 | data = []
192 | for feature in self.files[recording_index]:
193 | data += [np.load(feature).astype(np.float32)]
194 | data = self.prepare_data(data)
195 | return tuple(tf.constant(x) for x in data)
196 |
197 |
198 | def __call__(self):
199 | """Load data for the next recording.
200 |
201 | Yields
202 | -------
203 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]]
204 | The features corresponding to the recording_index recording
205 | """
206 | for idx in range(self.__len__()):
207 | yield self.__getitem__(idx)
208 |
209 | if idx == self.__len__() - 1:
210 | self.on_epoch_end()
211 |
212 | def on_epoch_end(self):
213 | """Change state at the end of an epoch."""
214 | np.random.shuffle(self.files)
215 |
216 | def prepare_data(self, data):
217 | """Creates mismatch (imposter) envelope.
218 |
219 | Parameters
220 | ----------
221 | data: Sequence[numpy.ndarray]
222 | Data to create an imposter for.
223 |
224 | Returns
225 | -------
226 | tuple (numpy.ndarray, numpy.ndarray, numpy.ndarray, ...)
227 | (EEG, matched stimulus feature, mismatched stimulus feature, ...).
228 | """
229 | eeg = data[0]
230 | new_length = eeg.shape[0] - self.window_length - self.spacing
231 | resulting_data = [eeg[:new_length, ...]]
232 | for stimulus_feature in data[1:]:
233 | match_feature = stimulus_feature[:new_length, ...]
234 | mismatch_feature = stimulus_feature[
235 | self.spacing + self.window_length:, ...
236 | ]
237 | resulting_data += [match_feature, mismatch_feature]
238 | return resulting_data
239 |
240 |
241 |
242 | class RegressionDataGenerator:
243 | """Generate data for the regression task."""
244 |
245 | def __init__(
246 | self,
247 | files,
248 | window_length=None,
249 | high_pass_freq=None,
250 | low_pass_freq=None,
251 | return_filenames=False,
252 | ):
253 | """Initialize the DataGenerator.
254 |
255 | Parameters
256 | ----------
257 | files: Sequence[Union[str, pathlib.Path]]
258 | Files to load.
259 | window_length: int
260 | Length of the decision window.
261 | """
262 | self.files = self.group_recordings(files)
263 | self.return_filenames = return_filenames
264 | self.high_pass_freq = high_pass_freq
265 | self.low_pass_freq = low_pass_freq
266 |
267 | if self.high_pass_freq and self.low_pass_freq:
268 | self.filter_ = scipy.signal.butter(N=1,
269 | Wn=[self.high_pass_freq, self.low_pass_freq],
270 | btype="bandpass",
271 | fs=64,
272 | output="sos")
273 | elif self.high_pass_freq:
274 | self.filter_ = scipy.signal.butter(N=1,
275 | Wn=self.high_pass_freq,
276 | btype="highpass",
277 | fs=64,
278 | output="sos")
279 | elif self.low_pass_freq:
280 | self.filter_ = scipy.signal.butter(N=1,
281 | Wn=self.low_pass_freq,
282 | btype="lowpass",
283 | fs=64,
284 | output="sos")
285 | else:
286 | self.filter_ = None
287 |
288 |
289 | def group_recordings(self, files):
290 | """Group recordings and corresponding stimuli.
291 |
292 | Parameters
293 | ----------
294 | files : Sequence[Union[str, pathlib.Path]]
295 | List of filepaths to preprocessed and split EEG and speech features
296 |
297 | Returns
298 | -------
299 | list
300 | Files grouped by the self.group_key_fn and subsequently sorted
301 | by the self.feature_sort_fn.
302 | """
303 | new_files = []
304 | grouped = itertools.groupby(sorted(files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3]))
305 | for recording_name, feature_paths in grouped:
306 | new_files += [sorted(feature_paths, key=lambda x: "0" if os.path.basename(x).split("_-_")[-1].split(".")[0] == "eeg" else x)]
307 | return new_files
308 |
309 | def __len__(self):
310 | return len(self.files)
311 |
312 | def __getitem__(self, recording_index):
313 | """Get data for a certain recording.
314 |
315 | Parameters
316 | ----------
317 | recording_index: int
318 | Index of the recording in this dataset
319 |
320 | Returns
321 | -------
322 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]]
323 | The features corresponding to the recording_index recording
324 | """
325 | data = []
326 | for feature in self.files[recording_index]:
327 | data += [np.load(feature).astype(np.float32)]
328 |
329 | data = self.prepare_data(data)
330 | if self.return_filenames:
331 | return self.files[recording_index], tuple(tf.constant(x) for x in data)
332 | else:
333 | return tuple(tf.constant(x) for x in data)
334 |
335 |
336 | def __call__(self):
337 | """Load data for the next recording.
338 |
339 | Yields
340 | -------
341 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]]
342 | The features corresponding to the recording_index recording
343 | """
344 | for idx in range(self.__len__()):
345 | yield self.__getitem__(idx)
346 |
347 | if idx == self.__len__() - 1:
348 | self.on_epoch_end()
349 |
350 | def on_epoch_end(self):
351 | """Change state at the end of an epoch."""
352 | np.random.shuffle(self.files)
353 |
354 | def prepare_data(self, data):
355 | """ If specified, filter the data between highpass and lowpass
356 | :param data: list of numpy arrays, eeg and envelope
357 | :return: filtered data
358 |
359 | """
360 |
361 | if self.filter_ is not None:
362 | resulting_data = []
363 | # assuming time is the first dimension and channels the second
364 | resulting_data.append(scipy.signal.sosfiltfilt(self.filter_, data[0], axis=0))
365 |
366 | for stimulus_feature in data[1:]:
367 | resulting_data.append(scipy.signal.sosfiltfilt(self.filter_, stimulus_feature, axis=0))
368 |
369 | else:
370 | resulting_data = data
371 |
372 | return resulting_data
373 |
374 |
375 |
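376 | # Example usage (sketch; assumes split feature files named like
377 | # "train_-_sub-001_-_<stimulus>_-_eeg.npy", as produced by
378 | # technical_validation/util/split_and_normalize.py; the window parameters are hypothetical):
379 | #
380 | #     import glob, os
381 | #     files = glob.glob(os.path.join(data_folder, "train_-_*"))
382 | #     generator = MatchMismatchDataGenerator(files, window_length=320, spacing=64)
383 | #     dataset = create_tf_dataset(generator, window_length=320,
384 | #                                 batch_equalizer_fn=default_batch_equalizer_fn)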
--------------------------------------------------------------------------------
/preprocessing_code/sparrKULee.py:
--------------------------------------------------------------------------------
1 | """Run the default preprocessing pipeline on sparrKULee."""
2 | import argparse
3 | import datetime
4 | import gzip
5 | import json
6 | import logging
7 | import os
8 | from typing import Any, Dict, Sequence
9 |
10 | import librosa
11 | import numpy as np
12 | import scipy.signal
13 | import scipy.signal.windows
14 |
15 | from brain_pipe.dataloaders.path import GlobLoader
16 | from brain_pipe.pipeline.default import DefaultPipeline
17 | from brain_pipe.preprocessing.brain.artifact import (
18 | InterpolateArtifacts,
19 | ArtifactRemovalMWF,
20 | )
21 | from brain_pipe.preprocessing.brain.eeg.biosemi import (
22 | biosemi_trigger_processing_fn,
23 | )
24 | from brain_pipe.preprocessing.brain.eeg.load import LoadEEGNumpy
25 | from brain_pipe.preprocessing.brain.epochs import SplitEpochs
26 | from brain_pipe.preprocessing.brain.link import (
27 | LinkStimulusToBrainResponse,
28 | BIDSStimulusInfoExtractor,
29 | )
30 | from brain_pipe.preprocessing.brain.rereference import CommonAverageRereference
31 | from brain_pipe.preprocessing.brain.trigger import (
32 | AlignPeriodicBlockTriggers,
33 | )
34 | from brain_pipe.preprocessing.filter import SosFiltFilt
35 | from brain_pipe.preprocessing.resample import ResamplePoly
36 | from brain_pipe.preprocessing.stimulus.audio.spectrogram import LibrosaMelSpectrogram
37 | from brain_pipe.preprocessing.stimulus.audio.envelope import GammatoneEnvelope
38 | from brain_pipe.preprocessing.stimulus.load import LoadStimuli
39 | from brain_pipe.runner.default import DefaultRunner
40 | from brain_pipe.save.default import DefaultSave
41 | from brain_pipe.utils.log import default_logging, DefaultFormatter
42 | from brain_pipe.utils.path import BIDSStimulusGrouper
43 |
44 |
45 | class BIDSAPRStimulusInfoExtractor(BIDSStimulusInfoExtractor):
46 |     """Extract BIDS-compliant stimulus information, augmented with .apr data."""
47 |
48 | def __call__(self, brain_dict: Dict[str, Any]):
49 |         """Extract stimulus information from the events.tsv and .apr files.
50 |
51 | Parameters
52 | ----------
53 | brain_dict: Dict[str, Any]
54 | The data dict containing the brain data path.
55 |
56 | Returns
57 | -------
58 | Sequence[Dict[str, Any]]
59 |             The extracted event information. Each dict contains the information
60 |             of one row in the events.tsv file, augmented with the .apr data.
61 | """
62 | event_info = super().__call__(brain_dict)
63 | # Find the apr file
64 | path = brain_dict[self.brain_path_key]
65 | apr_path = "_".join(path.split("_")[:-1]) + "_eeg.apr"
66 | # Read apr file
67 | apr_data = self.get_apr_data(apr_path)
68 | # Add apr data to event info
69 | for e_i in event_info:
70 | e_i.update(apr_data)
71 | return event_info
72 |
73 | def get_apr_data(self, apr_path: str):
74 | """Get the SNR from an .apr file.
75 |
76 | Parameters
77 | ----------
78 | apr_path: str
79 | Path to the .apr file.
80 |
81 | Returns
82 | -------
83 | Dict[str, Any]
84 |             A dict containing the SNR under the key "snr".
85 | """
86 | import xml.etree.ElementTree as ET
87 |
88 | apr_data = {}
89 | tree = ET.parse(apr_path)
90 | root = tree.getroot()
91 |
92 | # Get SNR
93 | interactive_elements = root.findall(".//interactive/entry")
94 | for element in interactive_elements:
95 | description_element = element.find("description")
96 |             if description_element is not None and description_element.text == "SNR":
97 | apr_data["snr"] = element.find("new_value").text
98 | if "snr" not in apr_data:
99 | logging.warning(f"Could not find SNR in {apr_path}.")
100 |             apr_data["snr"] = 100.0  # Fall back to a high default SNR
101 | return apr_data
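
    # For reference, a minimal sketch of the .apr XML structure this parser
    # assumes (root tag and values are illustrative, not taken from the data):
    #
    #   <results>
    #     <interactive>
    #       <entry>
    #         <description>SNR</description>
    #         <new_value>4.0</new_value>
    #       </entry>
    #     </interactive>
    #   </results>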
102 |
103 |
104 | def default_librosa_load_fn(path):
105 | """Load a stimulus using librosa.
106 |
107 | Parameters
108 | ----------
109 | path: str
110 | Path to the audio file.
111 |
112 | Returns
113 | -------
114 | Dict[str, Any]
115 | The data and the sampling rate.
116 | """
117 | data, sr = librosa.load(path, sr=None)
118 | return {"data": data, "sr": sr}
119 |
120 |
121 | def default_npz_load_fn(path):
122 | """Load a stimulus from a .npz file.
123 |
124 | Parameters
125 | ----------
126 | path: str
127 | Path to the .npz file.
128 |
129 | Returns
130 | -------
131 | Dict[str, Any]
132 | The data and the sampling rate.
133 | """
134 | np_data = np.load(path)
135 | return {
136 | "data": np_data["audio"],
137 | "sr": np_data["fs"],
138 | }
139 |
140 |
141 | DEFAULT_LOAD_FNS = {
142 | ".wav": default_librosa_load_fn,
143 | ".mp3": default_librosa_load_fn,
144 | ".npz": default_npz_load_fn,
145 | }
146 |
147 |
148 | def temp_stimulus_load_fn(path):
149 | """Load stimuli from (Gzipped) files.
150 |
151 | Parameters
152 | ----------
153 | path: str
154 | Path to the stimulus file.
155 |
156 | Returns
157 | -------
158 | Dict[str, Any]
159 | Dict containing the data under the key "data" and the sampling rate
160 | under the key "sr".
161 | """
162 | if path.endswith(".gz"):
163 | with gzip.open(path, "rb") as f_in:
164 | data = dict(np.load(f_in))
165 | return {
166 | "data": data["audio"],
167 | "sr": data["fs"],
168 | }
169 |
170 |     extension = "." + ".".join(os.path.basename(path).split(".")[1:])
171 | if extension not in DEFAULT_LOAD_FNS:
172 | raise ValueError(
173 | f"Can't find a load function for extension {extension}. "
174 | f"Available extensions are {str(list(DEFAULT_LOAD_FNS.keys()))}."
175 | )
176 | load_fn = DEFAULT_LOAD_FNS[extension]
177 | return load_fn(path)
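
# Example call (hypothetical path): gzipped and plain stimulus files resolve
# to the same {"data": ..., "sr": ...} dict:
#
#   stimulus = temp_stimulus_load_fn("stimuli/audiobook_1.npz.gz")
#   audio, fs = stimulus["data"], stimulus["sr"]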
178 |
179 |
180 | def bids_filename_fn(data_dict, feature_name, set_name=None):
181 | """Default function to generate a filename for the data.
182 |
183 | Parameters
184 | ----------
185 | data_dict: Dict[str, Any]
186 | The data dict containing the data to save.
187 | feature_name: str
188 | The name of the feature.
189 | set_name: Optional[str]
190 | The name of the set. If no set name is given, the set name is not
191 | included in the filename.
192 |
193 | Returns
194 | -------
195 | str
196 | The filename.
197 | """
198 |     filename = os.path.basename(data_dict["data_path"]).split("_eeg")[0]
199 |     subject, session = filename.split("_")[:2]
200 |     stimulus_name = os.path.basename(
201 |         data_dict.get("stimulus_path", "*.")
202 |     ).split(".")[0]
203 |     filename += f"_desc-preproc-audio-{stimulus_name}_{feature_name}"
204 | 
205 | if set_name is not None:
206 | filename += f"_set-{set_name}"
207 |
208 | return os.path.join(subject, session, filename + ".npy")
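
# Worked example (hypothetical inputs): for
#   data_dict = {"data_path": ".../sub-001_ses-01_task-listening_eeg.bdf",
#                "stimulus_path": ".../audiobook_1.npz.gz"}
# and feature_name="eeg", this returns
# "sub-001/ses-01/sub-001_ses-01_task-listening_desc-preproc-audio-audiobook_1_eeg.npy".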
209 |
210 |
211 | class SparrKULeeSpectrogramKwargs:
212 |     """Callable that generates the kwargs for the librosa mel spectrogram."""
213 |
214 | def __init__(
215 | self,
216 | stimulus_sr_key="stimulus_sr",
217 | target_fs=64,
218 | hop_length=None,
219 | win_length_sec=0.025,
220 | n_fft=None,
221 | window_fn=None,
222 | n_mels=28,
223 | fmin=-4.2735,
224 | fmax=5444,
225 | power=1.0,
226 | center=False,
227 | norm=None,
228 | htk=True,
229 | ):
230 | self.stimulus_sr_key = stimulus_sr_key
231 | self.target_fs = target_fs
232 | self.hop_length = hop_length
233 | self.win_length_sec = win_length_sec
234 | self.n_fft = n_fft
235 | self.window_fn = window_fn
236 | if window_fn is None:
237 | self.window_fn = scipy.signal.windows.hamming
238 | self.n_mels = n_mels
239 | self.fmin = fmin
240 | self.fmax = fmax
241 | self.power = power
242 | self.center = center
243 | self.norm = norm
244 | self.htk = htk
245 |
246 | def __call__(self, data_dict):
247 |         """Generate the kwargs for the librosa mel spectrogram.
248 | 
249 |         Parameters
250 |         ----------
251 |         data_dict: Dict[str, Any]
252 |             The data dict containing the stimulus sampling rate.
253 |
254 | Returns
255 | -------
256 | Dict[str, Any]
257 | The kwargs for the librosa spectrogram.
258 |
259 | Notes
260 | -----
261 |         Based on the code for the 2023 Auditory EEG Challenge:
262 | https://github.com/exporl/auditory-eeg-challenge-2023-code/blob/main/
263 | task1_match_mismatch/util/mel_spectrogram.py
264 | """
265 | fs = data_dict[self.stimulus_sr_key]
266 | result = {
267 | "fmin": self.fmin,
268 | "fmax": self.fmax,
269 | "n_mels": self.n_mels,
270 | "power": self.power,
271 | "center": self.center,
272 | "norm": self.norm,
273 | "htk": self.htk,
274 | }
275 |
276 | result["hop_length"] = self.hop_length
277 | if self.hop_length is None:
278 | result["hop_length"] = int((1 / self.target_fs) * fs)
279 |
280 | result["win_length"] = self.win_length_sec
281 | if self.win_length_sec is not None:
282 | result["win_length"] = int(self.win_length_sec * fs)
283 |
284 | result["n_fft"] = self.n_fft
285 | if self.n_fft is None:
286 | result["n_fft"] = int(2 ** np.ceil(np.log2(result["win_length"])))
287 |
288 | result["window"] = self.window_fn(result["win_length"])
289 | return result
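
    # Worked example: for a stimulus sampled at fs = 48000 Hz with the
    # defaults above, hop_length = int((1 / 64) * 48000) = 750 samples,
    # win_length = int(0.025 * 48000) = 1200 samples, and
    # n_fft = 2 ** ceil(log2(1200)) = 2048, i.e. roughly 64 spectrogram
    # frames per second.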
290 |
291 |
292 | def run_preprocessing_pipeline(
293 | root_dir,
294 | preprocessed_stimuli_dir,
295 | preprocessed_eeg_dir,
296 | nb_processes=-1,
297 | overwrite=False,
298 | log_path="sparrKULee.log",
299 | ):
300 | """Construct and run the preprocessing on SparrKULee.
301 |
302 | Parameters
303 | ----------
304 | root_dir: str
305 | The root directory of the dataset.
306 |     preprocessed_stimuli_dir: str
307 |         The directory where the preprocessed stimuli should be saved.
308 |     preprocessed_eeg_dir: str
309 |         The directory where the preprocessed EEG should be saved.
310 | nb_processes: int
311 | The number of processes to use. If -1, the number of processes is
312 | automatically determined.
313 | overwrite: bool
314 | Whether to overwrite existing files.
315 | log_path: str
316 | The path to the log file.
317 | """
318 | #########
319 | # PATHS #
320 | #########
321 | os.makedirs(preprocessed_eeg_dir, exist_ok=True)
322 | os.makedirs(preprocessed_stimuli_dir, exist_ok=True)
323 |
324 | ###########
325 | # LOGGING #
326 | ###########
327 | handler = logging.FileHandler(log_path)
328 | handler.setLevel(logging.DEBUG)
329 | handler.setFormatter(DefaultFormatter())
330 | default_logging(handlers=[handler])
331 |
332 | ################
333 | # DATA LOADING #
334 | ################
335 |     logging.info("Collecting the EEG recordings from the BIDS folder...")
336 | data_loader = GlobLoader(
337 | [os.path.join(root_dir, "sub-*", "*", "eeg", "*.bdf*")],
338 | filter_fns=[lambda x: "restingState" not in x],
339 | key="data_path",
340 | )
341 |
342 | #########
343 | # STEPS #
344 | #########
345 |
346 | stimulus_steps = DefaultPipeline(
347 | steps=[
348 | LoadStimuli(load_fn=temp_stimulus_load_fn),
349 | GammatoneEnvelope(),
350 | LibrosaMelSpectrogram(
351 | power_factor=0.6, librosa_kwargs=SparrKULeeSpectrogramKwargs()
352 | ),
353 | ResamplePoly(64, "envelope_data", "stimulus_sr"),
354 |             # Save the gammatone envelope and (optionally) the mel spectrogram
355 | DefaultSave(
356 | preprocessed_stimuli_dir,
357 | to_save={
358 | "envelope": "envelope_data",
359 | # Comment out the next line if you don't want to use mel
360 | "mel": "spectrogram_data",
361 | },
362 | overwrite=overwrite,
363 | ),
364 | DefaultSave(preprocessed_stimuli_dir, overwrite=overwrite),
365 | ],
366 | on_error=DefaultPipeline.RAISE,
367 | )
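
    # Each stimulus thus yields a gammatone envelope and a mel spectrogram;
    # the envelope is explicitly resampled to 64 Hz, while the spectrogram
    # already targets 64 frames per second through its hop length.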
368 |
369 | eeg_steps = [
370 | LinkStimulusToBrainResponse(
371 | stimulus_data=stimulus_steps,
372 | extract_stimuli_information_fn=BIDSAPRStimulusInfoExtractor(),
373 | grouper=BIDSStimulusGrouper(
374 | bids_root=root_dir,
375 | mapping={"stim_file": "stimulus_path", "trigger_file": "trigger_path"},
376 | subfolders=["stimuli", "eeg"],
377 | ),
378 | ),
379 | LoadEEGNumpy(unit_multiplier=1e6, channels_to_select=list(range(64))),
380 | SosFiltFilt(
381 | scipy.signal.butter(1, 0.5, "highpass", fs=1024, output="sos"),
382 | emulate_matlab=True,
383 | axis=1,
384 | ),
385 | InterpolateArtifacts(),
386 | AlignPeriodicBlockTriggers(biosemi_trigger_processing_fn),
387 | SplitEpochs(),
388 | ArtifactRemovalMWF(),
389 | CommonAverageRereference(),
390 | ResamplePoly(64, axis=1),
391 | DefaultSave(
392 | preprocessed_eeg_dir,
393 | {"eeg": "data"},
394 | overwrite=overwrite,
395 | clear_output=True,
396 | filename_fn=bids_filename_fn,
397 | ),
398 | ]
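
    # In short, each EEG recording is loaded (first 64 channels, scaled to
    # microvolts), high-pass filtered at 0.5 Hz, cleaned of artifacts,
    # aligned to the stimulus triggers, split into per-stimulus epochs,
    # MWF-denoised, common-average rereferenced, and downsampled to 64 Hz
    # before saving.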
399 |
400 |     ########################
401 |     # RUNNING THE PIPELINE #
402 |     ########################
403 |
404 | logging.info("Starting with the EEG preprocessing")
405 | logging.info("===================================")
406 |
407 |     # Create the EEG pipeline; the data loader yields a data dict per
408 |     # EEG recording, which is then sent through the pipeline steps.
409 | eeg_pipeline = DefaultPipeline(steps=eeg_steps)
410 |
411 | DefaultRunner(
412 | nb_processes=nb_processes,
413 | logging_config=lambda: None,
414 | ).run(
415 | [(data_loader, eeg_pipeline)],
416 | )
417 |
418 |
419 | if __name__ == "__main__":
420 | # Code for the sparrKULee dataset
421 | # (https://rdr.kuleuven.be/dataset.xhtml?persistentId=doi:10.48804/K3VSND)
422 | #
423 |     # A slight adaptation of this code can also be found in the SparrKULee repository:
424 | # https://github.com/exporl/auditory-eeg-dataset
425 | # under preprocessing_code/sparrKULee.py
426 | # Load the config
427 | with open("config.json", "r") as f:
428 | config = json.load(f)
429 |
430 | # Set the correct paths as default arguments
431 | dataset_folder = config["dataset_folder"]
432 | derivatives_folder = os.path.join(dataset_folder, config["derivatives_folder"])
433 | preprocessed_stimuli_folder = os.path.join(
434 | derivatives_folder, config["preprocessed_stimuli_folder"]
435 | )
436 | preprocessed_eeg_folder = os.path.join(
437 | derivatives_folder, config["preprocessed_eeg_folder"]
438 | )
439 | # Set the default log folder
440 | default_log_folder = os.path.dirname(os.path.abspath(__file__))
441 |
442 | # Parse arguments from the command line
443 | parser = argparse.ArgumentParser(description="Preprocess the auditory EEG dataset")
444 | parser.add_argument(
445 | "--nb_processes",
446 | type=int,
447 | default=-1,
448 | help="Number of processes to use for the preprocessing. "
449 | "The default is to use all available cores (-1).",
450 | )
451 | parser.add_argument(
452 | "--overwrite", action="store_true", help="Overwrite existing files"
453 | )
454 | parser.add_argument(
455 |         "--log_path", type=str,
456 |         default=os.path.join(default_log_folder, "sparrKULee_{datetime}.log"),
457 |         help="Path to the log file. The '{datetime}' placeholder is "
458 |              "replaced by the current date and time.",
459 |     )
460 | parser.add_argument(
461 | "--dataset_folder",
462 | type=str,
463 | default=dataset_folder,
464 | help="Path to the folder where the dataset is downloaded",
465 | )
466 | parser.add_argument(
467 | "--preprocessed_stimuli_path",
468 | type=str,
469 | default=preprocessed_stimuli_folder,
470 | help="Path to the folder where the preprocessed stimuli will be saved",
471 | )
472 | parser.add_argument(
473 | "--preprocessed_eeg_path",
474 | type=str,
475 | default=preprocessed_eeg_folder,
476 | help="Path to the folder where the preprocessed EEG will be saved",
477 | )
478 | args = parser.parse_args()
479 |
480 | # Run the preprocessing pipeline
481 | run_preprocessing_pipeline(
482 | args.dataset_folder,
483 | args.preprocessed_stimuli_path,
484 | args.preprocessed_eeg_path,
485 | args.nb_processes,
486 | args.overwrite,
487 | args.log_path.format(
488 | datetime=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
489 | ),
490 | )
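
    # Example invocation (paths are placeholders):
    #   python preprocessing_code/sparrKULee.py \
    #       --dataset_folder /data/SparrKULee --nb_processes 4 --overwrite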
491 |
--------------------------------------------------------------------------------