├── requirements.txt
├── .idea
│   ├── encodings.xml
│   ├── vcs.xml
│   ├── .gitignore
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── auditory-eeg-dataset.iml
├── config.json
├── technical_validation
│   ├── experiments
│   │   ├── create_environment.sh
│   │   ├── regression_vlaai.py
│   │   ├── match_mismatch_dilated_convolutional_model.py
│   │   ├── regression_linear_backward_model.py
│   │   └── regression_linear_forward_model.py
│   ├── models
│   │   ├── linear.py
│   │   ├── dilated_convolutional_model.py
│   │   └── vlaai.py
│   └── util
│       ├── split_and_normalize.py
│       ├── plot_results.py
│       └── dataset_generator.py
├── download_code
│   ├── README.md
│   ├── download_script.py
│   └── __init__.py
├── preprocessing_code
│   ├── sparrKULee.yaml
│   └── sparrKULee.py
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | brian2
2 | brian2hears
3 | numpy
4 | scipy
5 | mne
6 | librosa
7 | pandas
8 | pybids
9 | seaborn
10 | brain_pipe>=0.0.3
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
6 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "dataset_folder": "null",
3 |   "derivatives_folder": "derivatives",
4 |   "preprocessed_eeg_folder": "preprocessed_eeg",
5 |   "preprocessed_stimuli_folder": "preprocessed_stimuli",
6 |   "split_folder": "split_data"
7 | }
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/technical_validation/experiments/create_environment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # create the necessary environment variables
4 | source ~/.bashrc
5 | source ~/.local/bin/load_conda_environment.sh bids_preprocessing
6 |
7 | printenv
8 |
9 | # Fix for unrecognized 'CUDA0' CUDA_VISIBLE_DEVICES
10 | export CUDA_VISIBLE_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | sed -e 's/CUDA//g' | sed -e 's/\://g')
11 | export PYTHONPATH="/esat/spchtemp/scratch/baccou/auditory-eeg-dataset/"
12 | echo "$@"
13 |
14 | # run the original command
15 | "$@"
16 |
17 |
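# Example invocation (illustrative; the experiment script is only an example —
# the wrapper accepts any command):
#   ./create_environment.sh python3 technical_validation/experiments/regression_vlaai.py
# The wrapper loads the conda environment, strips the 'CUDA' prefix from
# CUDA_VISIBLE_DEVICES, points PYTHONPATH at the repository checkout and then
# executes the command passed as arguments.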
--------------------------------------------------------------------------------
/.idea/auditory-eeg-dataset.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
12 |
--------------------------------------------------------------------------------
/download_code/README.md:
--------------------------------------------------------------------------------
1 | Tools/info to download the dataset
2 | ==================================
3 |
4 | Code to make downloading of the dataset more convenient.
5 |
6 | You can download the full dataset in `.zip` format [from the KU Leuven RDR site by clicking this link.](https://rdr.kuleuven.be/api/access/dataset/:persistentId/?persistentId=doi:10.48804/K3VSND)
7 |
8 | If you want more control over what you download and don't want to fully restart when downloading fails, you can use [download_script.py](./download_script.py) to download the dataset.
9 |
10 | The code in [download_script.py](./download_script.py) is meant to be used to download [the auditory EEG dataset from the RDR website](https://rdr.kuleuven.be/dataset.xhtml?persistentId=doi:10.48804/K3VSND), but can (possibly) be used to download other datasets from Dataverse servers.
11 |
12 | # Usage
13 |
14 | ```bash
15 | python download_script.py --help
16 | ```
17 |
18 | ```text
19 | usage: download_script.py [-h] [--server SERVER] [--dataset-id DATASET_ID] [--overwrite] [--skip-checksum] [--multiprocessing MULTIPROCESSING] [--subset {full,preprocessed,stimuli}] download_directory
20 |
21 | Download the auditory EEG dataset from RDR.
22 |
23 | positional arguments:
24 |   download_directory    Path to download the dataset to.
25 |
26 | options:
27 |   -h, --help            show this help message and exit
28 |   --server SERVER       The server to download the dataset from. Default: "rdr.kuleuven.be"
29 |   --dataset-id DATASET_ID
30 |                         The dataset ID to download. Default: "doi:10.48804/K3VSND"
31 |   --overwrite           Overwrite existing files.
32 |   --skip-checksum       Whether to skip checksums.
33 |   --multiprocessing MULTIPROCESSING
34 |                         Number of cores to use for multiprocessing. Default: -1 (all cores), set to 0 or 1 to disable multiprocessing.
35 |   --subset {full,preprocessed,stimuli}
36 |                         Download only a subset of the dataset. "full" downloads the full dataset, "preprocessed" downloads only the preprocessed data and "stimuli" downloads only the stimuli files. Default: full
37 |
38 | ```
39 |
--------------------------------------------------------------------------------
/download_code/download_script.py:
--------------------------------------------------------------------------------
1 | """Script to download the auditory EEG dataset from RDR."""
2 | import argparse
3 | import os.path
4 |
5 | from download_code import DataverseDownloader, DataverseParser
6 |
7 | if __name__ == "__main__":
8 |     download_dir = os.path.dirname(os.path.abspath(__file__))
9 |
10 |     parser = argparse.ArgumentParser(
11 |         description="Download the auditory EEG dataset from RDR."
12 |     )
13 |     parser.add_argument(
14 |         "--server",
15 |         default="rdr.kuleuven.be",
16 |         help='The server to download the dataset from. Default: "rdr.kuleuven.be"',
17 |     )
18 |     parser.add_argument(
19 |         "--dataset-id",
20 |         default="doi:10.48804/K3VSND",
21 |         help='The dataset ID to download. Default: "doi:10.48804/K3VSND"',
22 |     )
23 |     parser.add_argument(
24 |         "--overwrite", action="store_true", help="Overwrite existing files."
25 |     )
26 |     parser.add_argument(
27 |         "--skip-checksum", action="store_true", help="Whether to skip checksums."
28 |     )
29 |     parser.add_argument(
30 |         "--multiprocessing",
31 |         type=int,
32 |         default=-1,
33 |         help="Number of cores to use for multiprocessing. "
34 |         "Default: -1 (all cores), set to 0 or 1 to disable multiprocessing.",
35 |     )
36 |     parser.add_argument(
37 |         "--subset",
38 |         choices=["full", "preprocessed", "stimuli"],
39 |         default="full",
40 |         help='Download only a subset of the dataset. '
41 |         '"full" downloads the full dataset, '
42 |         '"preprocessed" downloads only the preprocessed data and '
43 |         '"stimuli" downloads only the stimuli files. Default: full',
44 |     )
45 |     parser.add_argument(
46 |         "download_directory", type=str, help="Path to download the dataset to."
47 |     )
48 |
49 |     args = parser.parse_args()
50 |
51 |     if args.subset == "full":
52 |
53 |         def filter_fn(path, file_id):
54 |             return True
55 |
56 |     elif args.subset == "preprocessed":
57 |
58 |         def filter_fn(path, file_id):
59 |             return path.startswith("derivatives/")
60 |
61 |     elif args.subset == "stimuli":
62 |
63 |         def filter_fn(path, file_id):
64 |             return path.startswith("stimuli/")
65 |
66 |     else:
67 |         raise ValueError(f"Unknown subset {args.subset}")
68 |
69 |     dataverse_parser = DataverseParser(args.server)
70 |     file_id_mapping = dataverse_parser(args.dataset_id)
71 |     downloader = DataverseDownloader(
72 |         args.download_directory,
73 |         args.server,
74 |         overwrite=args.overwrite,
75 |         multiprocessing=args.multiprocessing,
76 |         check_md5=not args.skip_checksum,
77 |     )
78 |     print(
79 |         f"Starting download of set {args.subset} from {args.server} to "
80 |         f"{args.download_directory}... (options: overwrite={args.overwrite}, "
81 |         f"multiprocessing={args.multiprocessing})"
82 |     )
83 |     print("This might take a while...")
84 |     downloader(file_id_mapping, filter_fn=filter_fn)
85 |
--------------------------------------------------------------------------------
/preprocessing_code/sparrKULee.yaml:
--------------------------------------------------------------------------------
1 | dataloaders:
2 |   - name: sparrkulee_eeg_loader
3 |     callable: GlobLoader
4 |     glob_patterns:
5 |       - {{ dataset_folder }}/sub-*/*/eeg/*_task-listeningActive_*.bdf*
6 |     key: data_path
7 |
8 | pipelines:
9 |   - callable: DefaultPipeline
10 |     data_from: sparrkulee_eeg_loader
11 |     steps:
12 |       - callable: LinkStimulusToBrainResponse
13 |         stimulus_data:
14 |           callable: DefaultPipeline
15 |           steps:
16 |             - callable: LoadStimuli
17 |               load_fn:
18 |                 callable: temp_stimulus_load_fn
19 |                 is_pointer: true
20 |             - callable: GammatoneEnvelope
21 |             # Comment out the following lines if the mel spectrogram is not needed
22 |             - callable: LibrosaMelSpectrogram
23 |               power_factor: 0.6
24 |               librosa_kwargs:
25 |                 callable: SparrKULeeSpectrogramKwargs
26 |             - callable: ResamplePoly
27 |               target_frequency: 64
28 |               data_key: envelope_data
29 |               sampling_frequency_key: stimulus_sr
30 |             - callable: DefaultSave
31 |               root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_stimuli_dir }}
32 |               to_save:
33 |                 envelope: envelope_data
34 |                 # Comment out the following line if the mel spectrogram is not needed
35 |                 mel: spectrogram_data
36 |             - callable: DefaultSave
37 |               root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_stimuli_dir }}
38 |               overwrite: false
39 |         grouper:
40 |           callable: BIDSStimulusGrouper
41 |           bids_root: {{ dataset_folder }}
42 |           mapping:
43 |             stim_file: stimulus_path
44 |             trigger_file: trigger_path
45 |           subfolders:
46 |             - stimuli
47 |             - eeg
48 |       - callable: LoadEEGNumpy
49 |         unit_multiplier: 1000000
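        # The unit_multiplier of 1e6 presumably rescales the loaded EEG signal
        # from volts to microvolts before further processing.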
50 |         channels_to_select:
51 |         {% for channel in range(64) %}
52 |           - {{ channel }}
53 |         {% endfor %}
54 |
55 |       - callable: SosFiltFilt
56 |         filter_:
57 |           callable: scipy.signal.butter
58 |           N: 1
59 |           Wn: 0.5
60 |           btype: highpass
61 |           fs: 1024
62 |           output: sos
63 |         emulate_matlab: true
64 |         axis: 1
65 |       - callable: InterpolateArtifacts
66 |       - callable: AlignPeriodicBlockTriggers
67 |         brain_trigger_processing_fn:
68 |           callable: biosemi_trigger_processing_fn
69 |           is_pointer: true
70 |       - callable: SplitEpochs
71 |       - callable: ArtifactRemovalMWF
72 |       - callable: CommonAverageRereference
73 |       - callable: ResamplePoly
74 |         target_frequency: 64
75 |         axis: 1
76 |       - callable: DefaultSave
77 |         root_dir: {{ dataset_folder }}/{{ derivatives_folder }}/{{ preprocessed_eeg_dir }}
78 |         to_save:
79 |           eeg: data
80 |         overwrite: false
81 |         filename_fn:
82 |           callable: bids_filename_fn
83 |           is_pointer: true
84 |         clear_output: true
85 |
86 | config:
87 |   parser:
88 |     extra_paths:
89 |       - {{ __filedir__ }}/sparrKULee.py
90 |   logging:
91 |     log_path: {{ __filedir__ }}/sparrKULee_{datetime}.log
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/technical_validation/models/linear.py:
--------------------------------------------------------------------------------
1 | """This module contains the linear backward model."""
2 | import tensorflow as tf
3 |
4 |
5 |
6 |
7 | @tf.function
8 | def pearson_loss_cut(y_true, y_pred, axis=1):
9 |     """Pearson loss function.
10 |
11 |     Parameters
12 |     ----------
13 |     y_true: tf.Tensor
14 |         True values. Shape is (batch_size, time_steps, n_features)
15 |     y_pred: tf.Tensor
16 |         Predicted values. Shape is (batch_size, time_steps, n_features)
17 |
18 |     Returns
19 |     -------
20 |     tf.Tensor
21 |         Pearson loss.
22 |         Shape is (batch_size, 1, n_features)
23 |     """
24 |     return -pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis)
25 |
26 |
27 | @tf.function
28 | def pearson_metric_cut(y_true, y_pred, axis=1):
29 |     """Pearson metric function.
30 |
31 |     Parameters
32 |     ----------
33 |     y_true: tf.Tensor
34 |         True values. Shape is (batch_size, time_steps, n_features)
35 |     y_pred: tf.Tensor
36 |         Predicted values. Shape is (batch_size, time_steps, n_features)
37 |
38 |     Returns
39 |     -------
40 |     tf.Tensor
41 |         Pearson metric.
42 |         Shape is (batch_size, 1, n_features)
43 |     """
44 |     return tf.reduce_mean(pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis), axis=-1)
45 |
46 | @tf.function
47 | def pearson_metric_cut_not_av(y_true, y_pred, axis=1):
48 |     """Pearson metric function.
49 |
50 |     Parameters
51 |     ----------
52 |     y_true: tf.Tensor
53 |         True values. Shape is (batch_size, time_steps, n_features)
54 |     y_pred: tf.Tensor
55 |         Predicted values. Shape is (batch_size, time_steps, n_features)
56 |
57 |     Returns
58 |     -------
59 |     tf.Tensor
60 |         Pearson metric.
61 |         Shape is (batch_size, 1, n_features)
62 |     """
63 |     return pearson_tf(y_true[:, : tf.shape(y_pred)[1], :], y_pred, axis=axis)
64 |
65 |
66 | def simple_linear_model(integration_window=32, nb_filters=1, nb_channels=64, metric=pearson_metric_cut):
67 |     inp = tf.keras.layers.Input(
68 |         (
69 |             None,
70 |             nb_channels,
71 |         )
72 |     )
73 |     out = tf.keras.layers.Conv1D(nb_filters, integration_window)(inp)
74 |     model = tf.keras.models.Model(inputs=[inp], outputs=[out])
75 |     model.compile(
76 |         tf.keras.optimizers.Adam(),
77 |         loss=pearson_loss_cut,
78 |         metrics=[metric]
79 |     )
80 |     return model
81 |
82 |
83 |
84 | def pearson_tf(y_true, y_pred, axis=1):
85 |     """Pearson correlation function implemented in tensorflow.
86 |
87 |     Parameters
88 |     ----------
89 |     y_true: tf.Tensor
90 |         Ground truth labels. Shape is (batch_size, time_steps, n_features)
91 |     y_pred: tf.Tensor
92 |         Predicted labels. Shape is (batch_size, time_steps, n_features)
93 |     axis: int
94 |         Axis along which to compute the pearson correlation. Default is 1.
95 |
96 |     Returns
97 |     -------
98 |     tf.Tensor
99 |         Pearson correlation.
100 |         Shape is (batch_size, 1, n_features) if axis is 1.
101 |     """
102 |     # Compute the mean of the true and predicted values
103 |     y_true_mean = tf.reduce_mean(y_true, axis=axis, keepdims=True)
104 |     y_pred_mean = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
105 |
106 |     # Compute the numerator and denominator of the pearson correlation
107 |     numerator = tf.reduce_sum(
108 |         (y_true - y_true_mean) * (y_pred - y_pred_mean),
109 |         axis=axis,
110 |         keepdims=True,
111 |     )
112 |     std_true = tf.reduce_sum(tf.square(y_true - y_true_mean), axis=axis, keepdims=True)
113 |     std_pred = tf.reduce_sum(tf.square(y_pred - y_pred_mean), axis=axis, keepdims=True)
114 |     denominator = tf.sqrt(std_true * std_pred)
115 |
116 |     # Compute the pearson correlation
117 |     return tf.math.divide_no_nan(numerator, denominator)
118 |
--------------------------------------------------------------------------------
/technical_validation/models/dilated_convolutional_model.py:
--------------------------------------------------------------------------------
1 | """Default dilation model."""
2 | import tensorflow as tf
3 |
4 |
5 | def dilation_model(
6 |     time_window=None,
7 |     eeg_input_dimension=64,
8 |     env_input_dimension=1,
9 |     layers=3,
10 |     kernel_size=3,
11 |     spatial_filters=8,
12 |     dilation_filters=16,
13 |     activation="relu",
14 |     compile=True,
15 |     inputs=tuple(),
16 | ):
17 |     """Convolutional dilation model.
18 |
19 |     Code was taken and adapted from
20 |     https://github.com/exporl/eeg-matching-eusipco2020
21 |
22 |     Parameters
23 |     ----------
24 |     time_window : int or None
25 |         Segment length. If None, the model will accept every time window input
26 |         length.
27 |     eeg_input_dimension : int
28 |         number of channels of the EEG
29 |     env_input_dimension : int
30 |         dimension of the stimulus representation.
31 |         if stimulus == envelope, env_input_dimension = 1
32 |         if stimulus == mel, env_input_dimension = 28
33 |     layers : int
34 |         Depth of the network/Number of layers
35 |     kernel_size : int
36 |         Size of the kernel for the dilation convolutions
37 |     spatial_filters : int
38 |         Number of parallel filters to use in the spatial layer
39 |     dilation_filters : int
40 |         Number of parallel filters to use in the dilation layers
41 |     activation : str or list or tuple
42 |         Name of the non-linearity to apply after the dilation layers
43 |         or list/tuple of different non-linearities
44 |     compile : bool
45 |         Whether the model should be compiled
46 |     inputs : tuple
47 |         Alternative inputs
48 |
49 |     Returns
50 |     -------
51 |     tf.keras.Model
52 |         The dilation model
53 |
54 |
55 |     References
56 |     ----------
57 |     Accou, B., Jalilpour Monesi, M., Montoya, J., Van hamme, H. & Francart, T.
58 |     Modeling the relationship between acoustic stimulus and EEG with a dilated
59 |     convolutional neural network. In 2020 28th European Signal Processing
60 |     Conference (EUSIPCO), 1175–1179, DOI: 10.23919/Eusipco47968.2020.9287417
61 |     (2021). ISSN: 2076-1465.
62 |
63 |     Accou, B., Monesi, M. J., hamme, H. V. & Francart, T.
64 |     Predicting speech intelligibility from EEG in a non-linear classification
65 |     paradigm. J. Neural Eng. 18, 066008, DOI: 10.1088/1741-2552/ac33e9 (2021).
66 |     Publisher: IOP Publishing
67 |     """
68 |     # If different inputs are required
69 |     if len(inputs) == 3:
70 |         eeg, env1, env2 = inputs[0], inputs[1], inputs[2]
71 |     else:
72 |         eeg = tf.keras.layers.Input(shape=[time_window, eeg_input_dimension])
73 |         env1 = tf.keras.layers.Input(shape=[time_window, env_input_dimension])
74 |         env2 = tf.keras.layers.Input(shape=[time_window, env_input_dimension])
75 |
76 |     # Activations to apply
77 |     if isinstance(activation, str):
78 |         activations = [activation] * layers
79 |     else:
80 |         activations = activation
81 |
82 |     env_proj_1 = env1
83 |     env_proj_2 = env2
84 |     # Spatial convolution
85 |     eeg_proj_1 = tf.keras.layers.Conv1D(spatial_filters, kernel_size=1)(eeg)
86 |
87 |     # Construct dilation layers
88 |     for layer_index in range(layers):
89 |         # dilation on EEG
90 |         eeg_proj_1 = tf.keras.layers.Conv1D(
91 |             dilation_filters,
92 |             kernel_size=kernel_size,
93 |             dilation_rate=kernel_size**layer_index,
94 |             strides=1,
95 |             activation=activations[layer_index],
96 |         )(eeg_proj_1)
97 |
98 |         # Dilation on envelope data, share weights
99 |         env_proj_layer = tf.keras.layers.Conv1D(
100 |             dilation_filters,
101 |             kernel_size=kernel_size,
102 |             dilation_rate=kernel_size**layer_index,
103 |             strides=1,
104 |             activation=activations[layer_index],
105 |         )
106 |         env_proj_1 = env_proj_layer(env_proj_1)
107 |         env_proj_2 = env_proj_layer(env_proj_2)
108 |
109 |     # Comparison
110 |     cos1 = tf.keras.layers.Dot(1, normalize=True)([eeg_proj_1, env_proj_1])
111 |     cos2 = tf.keras.layers.Dot(1, normalize=True)([eeg_proj_1, env_proj_2])
112 |
113 |     # Classification
114 |     out = tf.keras.layers.Dense(1, activation="sigmoid")(
115 |         tf.keras.layers.Flatten()(tf.keras.layers.Concatenate()([cos1, cos2]))
116 |     )
117 |
118 |     model = tf.keras.Model(inputs=[eeg, env1, env2], outputs=[out])
119 |
120 |     if compile:
121 |         model.compile(
122 |             optimizer=tf.keras.optimizers.Adam(),
123 |             metrics=["acc"],
124 |             loss=["binary_crossentropy"],
125 |         )
126 |         model.summary()
127 |     return model
128 |
--------------------------------------------------------------------------------
/technical_validation/util/split_and_normalize.py:
--------------------------------------------------------------------------------
1 | """Split data in sets and normalize (per recording)."""
2 | import glob
3 | import json
4 | import os
5 | import pickle
6 | import numpy as np
7 |
8 |
9 | if __name__ == "__main__":
10 |
11 |     # Arguments for splitting and normalizing
12 |     speech_features = ['envelope']
13 |     splits = [80, 10, 10]
14 |     split_names = ['train', 'val', 'test']
15 |     overwrite = False
16 |
17 |     # Calculate the split fraction
18 |     split_fractions = [x/sum(splits) for x in splits]
19 |
20 |     # Get the path to the config file
21 |     main_folder = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
22 |     config_path = os.path.join(main_folder, 'config.json')
23 |
24 |     # Load the config
25 |     with open(config_path) as fp:
26 |         config = json.load(fp)
27 |
28 |     # Construct the necessary paths
29 |     processed_eeg_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["preprocessed_eeg_folder"])
30 |     processed_stimuli_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["preprocessed_stimuli_folder"])
31 |     split_data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
32 |
33 |     # Create the output folder
34 |     os.makedirs(split_data_folder, exist_ok=True)
35 |
36 |     # Find all subjects
37 |     all_subjects = glob.glob(os.path.join(processed_eeg_folder, "sub*"))
38 |     nb_subjects = len(all_subjects)
39 |     print(f"Found {nb_subjects} subjects to split/normalize")
40 |
41 |     # Loop over subjects
42 |     for subject_index, subject_path in enumerate(all_subjects):
43 |         subject = os.path.basename(subject_path)
44 |         print(f"Starting with subject {subject} ({subject_index + 1}/{nb_subjects})...")
45 |         # Find all recordings
46 |         all_recordings = glob.glob(os.path.join(subject_path, "*", "*.npy"))
47 |         print(f"\tFound {len(all_recordings)} recordings for subject {subject}.")
48 |         # Loop over recordings
49 |         for recording_index, recording in enumerate(all_recordings):
50 |             print(f"\tStarting with recording {recording} ({recording_index + 1}/{len(all_recordings)})...")
51 |
52 |             # Load EEG from disk
53 |             print(f"\t\tLoading EEG for {recording}")
54 |             eeg = np.load(recording)
55 |
56 |             # swap axes to have time as first dimension
57 |             eeg = np.swapaxes(eeg, 0, 1)
58 |
59 |             # keep only the 64 EEG channels
60 |             eeg = eeg[:, :64]
61 |
62 |             # retrieve the stimulus name from the filename
63 |             stimulus_filename = recording.split('_eeg.')[0].split('-audio-')[1]
64 |
65 |             # Track the shortest length across the EEG and all stimulus features
66 |             shortest_length = eeg.shape[0]
67 |
68 |             # Create mapping between feature name and feature data
69 |             all_data_for_recording = {"eeg": eeg}
70 |
71 |             # Find corresponding stimuli for the EEG recording
72 |             for feature_name in speech_features:
73 |                 # Load feature from disk
74 |                 print(f"\t\tLoading {feature_name} for recording {recording} ")
75 |                 stimulus_feature_path = os.path.join(
76 |                     processed_stimuli_folder,
77 |                     stimulus_filename + "_" + feature_name + ".npy",
78 |                 )
79 |                 feature = np.load(stimulus_feature_path)
80 |                 # Calculate the shortest length
81 |                 shortest_length = min(feature.shape[0], shortest_length)
82 |                 # Update all_data_for_recording
83 |                 all_data_for_recording[feature_name] = feature
84 |
85 |             # Do the actual splitting
86 |             print(f"\t\tSplitting/normalizing recording {recording}...")
87 |             for feature_name, feature in all_data_for_recording.items():
88 |                 start_index = 0
89 |                 feature_mean = None
90 |                 feature_std = None
91 |
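                # The normalization statistics below are computed once, on the
                # first (training) split, and then reused for the validation and
                # test splits, so no statistics leak from the evaluation data.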
92 |                 for split_name, split_fraction in zip(split_names, split_fractions):
93 |                     end_index = start_index + int(shortest_length * split_fraction)
94 |
95 |                     # Cut the feature to this split's range (bounded by the shortest length)
96 |                     cut_feature = feature[start_index:end_index, ...]
97 |
98 |                     # Normalize the feature
99 |                     if feature_mean is None:
100 |                         feature_mean = np.mean(cut_feature, axis=0)
101 |                         feature_std = np.std(cut_feature, axis=0)
102 |                     norm_feature = (cut_feature - feature_mean) / feature_std
103 |
104 |                     # Save the normalized feature
105 |                     save_filename = f"{split_name}_-_{subject}_-_{stimulus_filename}_-_{feature_name}.npy"
106 |                     save_path = os.path.join(split_data_folder, save_filename)
107 |                     if not os.path.exists(save_path) or overwrite:
108 |                         np.save(save_path, norm_feature)
109 |                     else:
110 |                         print(f"\t\tSkipping {save_filename} because it already exists")
111 |                     start_index = end_index
112 |
113 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
163 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
164 |
165 | # User-specific stuff
166 | .idea/**/workspace.xml
167 | .idea/**/tasks.xml
168 | .idea/**/usage.statistics.xml
169 | .idea/**/dictionaries
170 | .idea/**/shelf
171 |
172 | # AWS User-specific
173 | .idea/**/aws.xml
174 |
175 | # Generated files
176 | .idea/**/contentModel.xml
177 |
178 | # Sensitive or high-churn files
179 | .idea/**/dataSources/
180 | .idea/**/dataSources.ids
181 | .idea/**/dataSources.local.xml
182 | .idea/**/sqlDataSources.xml
183 | .idea/**/dynamic.xml
184 | .idea/**/uiDesigner.xml
185 | .idea/**/dbnavigator.xml
186 |
187 | # Gradle
188 | .idea/**/gradle.xml
189 | .idea/**/libraries
190 |
191 | # Gradle and Maven with auto-import
192 | # When using Gradle or Maven with auto-import, you should exclude module files,
193 | # since they will be recreated, and may cause churn. Uncomment if using
194 | # auto-import.
195 | # .idea/artifacts
196 | # .idea/compiler.xml
197 | # .idea/jarRepositories.xml
198 | # .idea/modules.xml
199 | # .idea/*.iml
200 | # .idea/modules
201 | # *.iml
202 | # *.ipr
203 |
204 | # CMake
205 | cmake-build-*/
206 |
207 | # Mongo Explorer plugin
208 | .idea/**/mongoSettings.xml
209 |
210 | # File-based project format
211 | *.iws
212 |
213 | # IntelliJ
214 | out/
215 |
216 | # mpeltonen/sbt-idea plugin
217 | .idea_modules/
218 |
219 | # JIRA plugin
220 | atlassian-ide-plugin.xml
221 |
222 | # Cursive Clojure plugin
223 | .idea/replstate.xml
224 |
225 | # SonarLint plugin
226 | .idea/sonarlint/
227 |
228 | # Crashlytics plugin (for Android Studio and IntelliJ)
229 | com_crashlytics_export_strings.xml
230 | crashlytics.properties
231 | crashlytics-build.properties
232 | fabric.properties
233 |
234 | # Editor-based Rest Client
235 | .idea/httpRequests
236 |
237 | # Android studio 3.1+ serialized cache file
238 | .idea/caches/build_file_checksums.ser
239 |
240 |
241 | # No condor job files
242 | technical_validation/condor/
243 | *.job
244 | technical_validation/experiments/results*
245 | technical_validation/experiments/figures/
246 | # No tests with classified data
247 | tests/test_classified/*
248 | # No .py in progress files
249 | *.py~
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_vlaai.py:
--------------------------------------------------------------------------------
1 | """Example experiment for the VLAAI model."""
2 | import glob
3 | import json
4 | import logging
5 | import os
6 | import tensorflow as tf
7 | import sys
8 |
9 |
10 | from technical_validation.models.vlaai import vlaai, pearson_loss, pearson_metric
11 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
12 |
13 |
14 | def evaluate_model(model, test_dict):
15 |     """Evaluate a model.
16 |
17 |     Parameters
18 |     ----------
19 |     model: tf.keras.Model
20 |         Model to evaluate.
21 |     test_dict: dict
22 |         Mapping between a subject and a tf.data.Dataset containing the test
23 |         set for the subject.
24 |
25 |     Returns
26 |     -------
27 |     dict
28 |         Mapping between a subject and the loss/evaluation score on the test set
29 |     """
30 |     evaluation = {}
31 |     for subject, ds_test in test_dict.items():
32 |         logging.info(f"Scores for subject {subject}:")
33 |         results = model.evaluate(ds_test, verbose=2)
34 |         metrics = model.metrics_names
35 |         evaluation[subject] = dict(zip(metrics, results))
36 |     return evaluation
37 |
38 |
39 | if __name__ == "__main__":
40 |
41 |     gpus = tf.config.list_physical_devices('GPU')
42 |     print(gpus)
43 |     # if gpus:
44 |     #     try:
45 |     #         # Currently, memory growth needs to be the same across GPUs
46 |     #         for gpu in gpus:
47 |     #             tf.config.experimental.set_memory_growth(gpu, True)
48 |     #         logical_gpus = tf.config.list_logical_devices('GPU')
49 |     #         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
50 |     #     except RuntimeError as e:
51 |     #         # Memory growth must be set before GPUs have been initialized
52 |     #         print(e)
53 |     # Parameters
54 |     # Length of the decision window
55 |     window_length = 10 * 64  # 10 seconds
56 |     # Hop length between two consecutive decision windows
57 |     hop_length = 64
58 |     epochs = 100
59 |     patience = 5
60 |     batch_size = 64
61 |     only_evaluate = False
62 |     training_log_filename = "training_log_%d.csv" % window_length
63 |     results_filename = 'eval_%d.json' % window_length
64 |
65 |
66 |     # Get the path to the config file
67 |     experiments_folder = os.path.dirname(__file__)
68 |     main_folder = os.path.dirname(os.path.dirname(experiments_folder))
69 |     config_path = os.path.join(main_folder, 'config.json')
70 |
71 |     # Load the config
72 |     with open(config_path) as fp:
73 |         config = json.load(fp)
74 |
75 |     # Provide the path of the dataset
76 |     # which is already split into train, val, test
77 |
78 |     data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
79 |     stimulus_features = ["envelope"]
80 |     features = ["eeg"] + stimulus_features
81 |
82 |     # Create a directory to store (intermediate) results
83 |     results_folder = os.path.join(experiments_folder, "results_vlaai")
84 |     os.makedirs(results_folder, exist_ok=True)
85 |
86 |     # create the model
87 |     model = vlaai()
88 |     model.compile(tf.keras.optimizers.Adam(), loss=pearson_loss, metrics=[pearson_metric])
89 |     model_path = os.path.join(results_folder, "model_%d.h5" % window_length)
90 |
91 |     if only_evaluate:
92 |         model.load_weights(model_path)
93 |     else:
94 |
95 |         train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
96 |         # Create list of numpy array files
97 |         train_generator = RegressionDataGenerator(train_files, window_length)
98 |         dataset_train = create_tf_dataset(train_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64, 1))
99 |
100 |         # Create the generator for the validation set
101 |         val_files = [x for x in glob.glob(os.path.join(data_folder, "val_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
102 |         val_generator = RegressionDataGenerator(val_files, window_length)
103 |         dataset_val = create_tf_dataset(val_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64, 1))
104 |
105 |         # Train the model
106 |         model.fit(
107 |             dataset_train,
108 |             epochs=epochs,
109 |             validation_data=dataset_val,
110 |             callbacks=[
111 |                 tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True),
112 |                 tf.keras.callbacks.CSVLogger(os.path.join(results_folder, training_log_filename)),
113 |                 tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=True),
114 |             ],
115 |         )
116 |
117 |     # Evaluate the model on test set
118 |     # Create a dataset generator for each test subject
119 |     test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
120 |     # Get all different subjects from the test set
121 |     subjects = list(set([os.path.basename(x).split("_-_")[1] for x in test_files]))
122 |     datasets_test = {}
123 |     results_filename = 'eval_pretrained_%d.json' % window_length
124 |     # Create a generator for each subject
125 |     for sub in subjects:
126 |         files_test_sub = [f for f in test_files if sub in os.path.basename(f)]
127 |
128 |         test_generator = RegressionDataGenerator(files_test_sub, window_length)
129 |         datasets_test[sub] = create_tf_dataset(test_generator, window_length, None, hop_length, batch_size, data_types=(tf.float32, tf.float32), feature_dims=(64, 1))
130 |
131 |     # Evaluate the model
132 |     evaluation = evaluate_model(model, datasets_test)
133 |
134 |     # We can save our results in a json encoded file
135 |     results_path = os.path.join(results_folder, results_filename)
136 |     with open(results_path, "w") as fp:
137 |         json.dump(evaluation, fp)
138 |     logging.info(f"Results saved at {results_path}")
139 |
140 |
--------------------------------------------------------------------------------
/technical_validation/experiments/match_mismatch_dilated_convolutional_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for dilation model."""
2 | import glob
3 | import json
4 | import logging
5 | import os
6 | os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
7 | os.environ['TF_XLA_FLAGS'] = "--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
8 | import tensorflow as tf
9 | import sys
10 |
11 |
12 | from technical_validation.models.dilated_convolutional_model import dilation_model
13 | from technical_validation.util.dataset_generator import MatchMismatchDataGenerator, default_batch_equalizer_fn, create_tf_dataset
14 |
15 |
16 | def evaluate_model(model, test_dict):
17 |     """Evaluate a model.
18 |
19 |     Parameters
20 |     ----------
21 |     model: tf.keras.Model
22 |         Model to evaluate.
23 |     test_dict: dict
24 |         Mapping between a subject and a tf.data.Dataset containing the test
25 |         set for the subject.
26 |
27 |     Returns
28 |     -------
29 |     dict
30 |         Mapping between a subject and the loss/evaluation score on the test set
31 |     """
32 |     evaluation = {}
33 |     for subject, ds_test in test_dict.items():
34 |         logging.info(f"Scores for subject {subject}:")
35 |         results = model.evaluate(ds_test, verbose=2)
36 |         metrics = model.metrics_names
37 |         evaluation[subject] = dict(zip(metrics, results))
38 |     return evaluation
39 |
40 |
41 | if __name__ == "__main__":
42 |     # Parameters
43 |     # Length of the decision window
44 |     window_length = 5 * 64  # 5 seconds
45 |     # Hop length between two consecutive decision windows
46 |     hop_length = 64
47 |     # Number of samples (space) between end of matched speech and beginning of mismatched speech
48 |     spacing = 64
49 |     epochs = 100
50 |     patience = 5
51 |     batch_size = 64
52 |
53 |     only_evaluate = False
54 |
55 |     training_log_filename = "training_log.csv"
56 |     results_filename = 'eval.json'
57 |
58 |
59 |     # Get the path to the config file
60 |     experiments_folder = os.path.dirname(__file__)
61 |     main_folder = os.path.dirname(os.path.dirname(experiments_folder))
62 |     config_path = os.path.join(main_folder, 'config.json')
63 |
64 |     # Load the config
65 |     with open(config_path) as fp:
66 |         config = json.load(fp)
67 |
68 |     # Provide the path of the dataset
69 |     # which is already split into train, val, test
70 |     data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
71 |
72 |     # Stimulus feature used for training the model: either 'envelope' (dimension 1) or 'mel' (dimension 28)
73 |     stimulus_features = ["envelope"]
74 |     stimulus_dimension = 1
75 |
76 |     features = ["eeg"] + stimulus_features
77 |
78 |     # Create a directory to store (intermediate) results
79 |     results_folder = os.path.join(experiments_folder, "results_dilated_convolutional_model")
80 |
81 |     os.makedirs(results_folder, exist_ok=True)
82 |
83 |     # create dilation model
84 |     model = dilation_model(time_window=window_length, eeg_input_dimension=64, env_input_dimension=stimulus_dimension)
85 |     model_path = os.path.join(results_folder, "model.h5")
86 |
87 |     if only_evaluate:
88 |         model = tf.keras.models.load_model(model_path)
89 |     else:
90 |
91 |         train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
92 |         # Create list of numpy array files
93 |         train_generator = MatchMismatchDataGenerator(train_files, window_length, spacing=spacing)
94 |         dataset_train = create_tf_dataset(train_generator, window_length, default_batch_equalizer_fn, hop_length, batch_size)
95 |
96 |         # Create the generator for the validation set
97 |         val_files = [x for x in glob.glob(os.path.join(data_folder, "val_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
98 |         val_generator = MatchMismatchDataGenerator(val_files, window_length, spacing=spacing)
99 |         dataset_val = create_tf_dataset(val_generator, window_length, default_batch_equalizer_fn, hop_length, batch_size)
100 |
101 |         # Train the model
102 |         model.fit(
103 |             dataset_train,
104 |             epochs=epochs,
105 |             validation_data=dataset_val,
106 |             callbacks=[
107 |                 tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True),
108 |                 tf.keras.callbacks.CSVLogger(os.path.join(results_folder, training_log_filename)),
109 |                 tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=True),
110 |             ],
111 |         )
112 |
113 |
114 |
115 |     # Evaluate the model on test set
116 |     # Create a dataset generator for each test subject
117 |     test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features]
118 |     # Get all different subjects from the test set
119 |     subjects = list(set([os.path.basename(x).split("_-_")[1] for x in test_files]))
120 |
121 |
122 |     test_dict = {}
123 |     # evaluate on different window lengths
124 |     window_lengths = [64, 128, 256, 320, 640, 20 * 64]
125 |     for window_length in window_lengths:
126 |         model = dilation_model(time_window=window_length, eeg_input_dimension=64,
127 |                                env_input_dimension=stimulus_dimension)
128 |         model_path = os.path.join(results_folder, "model.h5")
129 |         model.load_weights(model_path)
130 |
131 |         datasets_test = {}
132 |         # Create a generator for each subject
133 |         for sub in subjects:
134 |             files_test_sub = [f for f in test_files if sub in os.path.basename(f)]
135 |             test_generator = MatchMismatchDataGenerator(files_test_sub, window_length, spacing=spacing)
136 |             datasets_test[sub] = create_tf_dataset(test_generator, window_length, default_batch_equalizer_fn, hop_length, 1)
137 |
138 |         # Evaluate the model
139 |         evaluation = evaluate_model(model, datasets_test)
140 |
141 |         # We can save our results in a json encoded file
142 |         results_path = os.path.join(results_folder, results_filename.split(".")[0] + f"_{window_length}" + ".json")
143 |         with open(results_path, "w") as fp:
144 |             json.dump(evaluation, fp)
145 |         logging.info(f"Results saved at {results_path}")
146 |
147 |
148 |
--------------------------------------------------------------------------------
/technical_validation/models/vlaai.py:
--------------------------------------------------------------------------------
1 | """Code to construct the VLAAI network.
2 | Code was extracted from https://github.com/exporl/vlaai
3 | """
4 | import tensorflow as tf
5 |
6 |
7 | def extractor(
8 |     filters=(256, 256, 256, 128, 128),
9 |     kernels=(8,) * 5,
10 |     input_channels=64,
11 |     normalization_fn=lambda x: tf.keras.layers.LayerNormalization()(x),
12 |     activation_fn=lambda x: tf.keras.layers.LeakyReLU()(x),
13 |     name="extractor",
14 | ):
15 |     """Construct the extractor model.
16 |
17 |     Parameters
18 |     ----------
19 |     filters: Sequence[int]
20 |         Number of filters for each layer.
21 |     kernels: Sequence[int]
22 |         Kernel size for each layer.
23 |     input_channels: int
24 |         Number of EEG channels in the input
25 |     normalization_fn: Callable[[tf.Tensor], tf.Tensor]
26 |         Function to normalize the contents of a tensor.
27 |     activation_fn: Callable[[tf.Tensor], tf.Tensor]
28 |         Function to apply an activation function to the contents of a tensor.
29 |     name: str
30 |         Name of the model.
31 |
32 |     Returns
33 |     -------
34 |     tf.keras.models.Model
35 |         The extractor model.
36 |     """
37 |     eeg = tf.keras.layers.Input((None, input_channels))
38 |
39 |     x = eeg
40 |
41 |     if len(filters) != len(kernels):
42 |         raise ValueError("'filters' and 'kernels' must have the same length")
43 |
44 |     # Add the convolutional layers
45 |     for filter_, kernel in zip(filters, kernels):
46 |         x = tf.keras.layers.Conv1D(filter_, kernel)(x)
47 |         x = normalization_fn(x)
48 |         x = activation_fn(x)
49 |         x = tf.keras.layers.ZeroPadding1D((0, kernel - 1))(x)
50 |
51 |     return tf.keras.models.Model(inputs=[eeg], outputs=[x], name=name)
52 |
53 |
54 | def output_context(
55 |     filter_=64,
56 |     kernel=32,
57 |     input_channels=64,
58 |     normalization_fn=lambda x: tf.keras.layers.LayerNormalization()(x),
59 |     activation_fn=lambda x: tf.keras.layers.LeakyReLU()(x),
60 |     name="output_context_model",
61 | ):
62 |     """Construct the output context model.
63 |
64 |     Parameters
65 |     ----------
66 |     filter_: int
67 |         Number of filters for the convolutional layer.
68 |     kernel: int
69 |         Kernel size for the convolutional layer.
70 |     input_channels: int
71 |         Number of EEG channels in the input.
72 |     normalization_fn: Callable[[tf.Tensor], tf.Tensor]
73 |         Function to normalize the contents of a tensor.
74 |     activation_fn: Callable[[tf.Tensor], tf.Tensor]
75 |         Function to apply an activation function to the contents of a tensor.
76 |     name: str
77 |         Name of the model.
78 |
79 |     Returns
80 |     -------
81 |     tf.keras.models.Model
82 |         The output context model.
83 |     """
84 |     inp = tf.keras.layers.Input((None, input_channels))
85 |     x = tf.keras.layers.ZeroPadding1D((kernel - 1, 0))(inp)
86 |     x = tf.keras.layers.Conv1D(filter_, kernel)(x)
87 |     x = normalization_fn(x)
88 |     x = activation_fn(x)
89 |     return tf.keras.models.Model(inputs=[inp], outputs=[x], name=name)
90 |
91 |
92 | def vlaai(
93 |     nb_blocks=4,
94 |     extractor_model=None,
95 |     output_context_model=None,
96 |     use_skip=True,
97 |     input_channels=64,
98 |     output_dim=1,
99 |     name="vlaai",
100 | ):
101 |     """Construct the VLAAI model.
102 |
103 |     Parameters
104 |     ----------
105 |     nb_blocks: int
106 |         Number of repeated blocks to use.
107 |     extractor_model: Callable[[tf.Tensor], tf.Tensor]
108 |         The extractor model to use.
109 |     output_context_model: Callable[[tf.Tensor], tf.Tensor]
110 |         The output context model to use.
111 |     use_skip: bool
112 |         Whether to use skip connections.
113 |     input_channels: int
114 |         Number of EEG channels in the input.
115 |     output_dim: int
116 |         Number of output dimensions.
117 |     name: str
118 |         Name of the model.
119 |
120 |     Returns
121 |     -------
122 |     tf.keras.models.Model
123 |         The VLAAI model.
124 |     """
125 |     if extractor_model is None:
126 |         extractor_model = extractor()
127 |     if output_context_model is None:
128 |         output_context_model = output_context()
129 |
130 |     eeg = tf.keras.layers.Input((None, input_channels))
131 |
132 |     # If using skip connections: start with x set to zero
133 |     if use_skip:
134 |         x = tf.zeros_like(eeg)
135 |     else:
136 |         x = eeg
137 |
138 |     # Iterate over the blocks
139 |     for i in range(nb_blocks):
140 |         if use_skip:
141 |             x = extractor_model(eeg + x)
142 |         else:
143 |             x = extractor_model(x)
144 |         x = tf.keras.layers.Dense(input_channels)(x)
145 |         x = output_context_model(x)
146 |
147 |     x = tf.keras.layers.Dense(output_dim)(x)
148 |
149 |     return tf.keras.models.Model(inputs=[eeg], outputs=[x], name=name)
150 |
151 |
152 | def pearson_tf(y_true, y_pred, axis=1):
153 |     """Pearson correlation function implemented in tensorflow.
154 |
155 |     Parameters
156 |     ----------
157 |     y_true: tf.Tensor
158 |         Ground truth labels. Shape is (batch_size, time_steps, n_features)
159 |     y_pred: tf.Tensor
160 |         Predicted labels. Shape is (batch_size, time_steps, n_features)
161 |     axis: int
162 |         Axis along which to compute the pearson correlation. Default is 1.
163 |
164 |     Returns
165 |     -------
166 |     tf.Tensor
167 |         Pearson correlation.
168 |         Shape is (batch_size, 1, n_features) if axis is 1.
169 |     """
170 |     # Compute the mean of the true and predicted values
171 |     y_true_mean = tf.reduce_mean(y_true, axis=axis, keepdims=True)
172 |     y_pred_mean = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
173 |
174 |     # Compute the numerator and denominator of the pearson correlation
175 |     numerator = tf.reduce_sum(
176 |         (y_true - y_true_mean) * (y_pred - y_pred_mean),
177 |         axis=axis,
178 |         keepdims=True,
179 |     )
180 |     std_true = tf.reduce_sum(tf.square(y_true - y_true_mean), axis=axis, keepdims=True)
181 |     std_pred = tf.reduce_sum(tf.square(y_pred - y_pred_mean), axis=axis, keepdims=True)
182 |     denominator = tf.sqrt(std_true * std_pred)
183 |
184 |     # Compute the pearson correlation
185 |     return tf.math.divide_no_nan(numerator, denominator)
186 |
187 |
188 | @tf.function
189 | def pearson_loss(y_true, y_pred, axis=1):
190 |     """Pearson loss function.
191 |
192 |     Parameters
193 |     ----------
194 |     y_true: tf.Tensor
195 |         True values. Shape is (batch_size, time_steps, n_features)
196 |     y_pred: tf.Tensor
197 |         Predicted values. Shape is (batch_size, time_steps, n_features)
198 |
199 |     Returns
200 |     -------
201 |     tf.Tensor
202 |         Pearson loss.
203 |         Shape is (batch_size, 1, n_features)
204 |     """
205 |     return -pearson_tf(y_true, y_pred, axis=axis)
206 |
207 |
208 | @tf.function
209 | def pearson_metric(y_true, y_pred, axis=1):
210 |     """Pearson metric function.
211 |
212 |     Parameters
213 |     ----------
214 |     y_true: tf.Tensor
215 |         True values. Shape is (batch_size, time_steps, n_features)
216 |     y_pred: tf.Tensor
217 |         Predicted values. Shape is (batch_size, time_steps, n_features)
218 |
219 |     Returns
220 |     -------
221 |     tf.Tensor
222 |         Pearson metric.
223 |         Shape is (batch_size, 1, n_features)
224 |     """
225 |     return pearson_tf(y_true, y_pred, axis=axis)
226 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code to preprocess the SparrKULee dataset
2 | =========================================
3 | This is the codebase to preprocess and validate the [SparrKULee](https://doi.org/10.48804/K3VSND) dataset.
4 | This codebase consists of two main parts:
5 | 1) preprocessing code, to preprocess the raw data into an easily usable format, and
6 | 2) technical validation code, to validate the technical quality of the dataset.
7 | The technical validation code is used to generate the results in the dataset paper and assumes that the preprocessing pipeline has been run.
8 |
9 | Requirements
10 | ------------
11 |
12 | Python >= 3.7
13 |
14 | # General setup
15 |
16 | Steps to get a working setup:
17 |
18 | ## 1. Clone this repository and install the [requirements.txt](requirements.txt)
19 | ```bash
20 | # Clone this repository
21 | git clone https://github.com/exporl/auditory-eeg-dataset
22 |
23 | # Go to the root folder
24 | cd auditory-eeg-dataset
25 |
26 | # Optional: install a virtual environment
27 | python3 -m venv venv  # Optional
28 | source venv/bin/activate  # Optional
29 |
30 | # Install requirements.txt
31 | python3 -m pip install -r requirements.txt
32 | ```
33 |
34 | ## 2. [Download (parts of the) data](download_code/README.md)
35 |
36 | **We recommend accessing the dataset via our [website](https://homes.esat.kuleuven.be/~spchdata/corpora/auditory_eeg_data/), where we also share instructions to mount the dataset as an alternative to downloading.**
37 |
38 | The dataset is also hosted on the [KU Leuven RDR website](https://doi.org/10.48804/K3VSND) and is accessible through DOI ([https://doi.org/10.48804/K3VSND](https://doi.org/10.48804/K3VSND)).
39 |
40 | However, due to the dataset size/structure and the limitations of the UI of the KU Leuven RDR website, we also provide a [direct download link for the entire dataset in `.zip` format](https://rdr.kuleuven.be/api/access/dataset/:persistentId/?persistentId=doi:10.48804/K3VSND), a [OneDrive repository containing the entire dataset split up into smaller files](https://kuleuven-my.sharepoint.com/:f:/g/personal/lies_bollens_kuleuven_be/EulH76nkcwxIuK--XJhLxKQBaX8_GgAX-rTKK7mskzmAZA?e=N6M5Ll) and a [tool to download (subsets of) the dataset robustly](download_code/README.md).
41 | For more information about the tool, see [download_code/README.md](download_code/README.md).
42 |
43 | Due to privacy concerns, not all data is publicly available. Users requesting access to these files should send an email to the authors (lies.bollens@kuleuven.be; bernd.accou@kuleuven.be), stating what they want to use the data for. Access will be granted to non-commercial users, complying with the CC-BY-NC-4.0 licence.
44 |
45 | When you want to directly start from the preprocessed data (which is the output you will get when running
46 | [preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)),
47 | you can download the **derivatives** folder. This folder contains all the necessary files to run the technical validation. It can also be downloaded using [the download tool](download_code/README.md) as follows:
48 |
49 | ```bash
50 | python3 download_code/download_script.py --subset preprocessed /path/to/local/folder
51 | ```
52 |
53 |
54 | ## 3. Adjust the [config.json](config.json) accordingly
55 |
56 | The [config.json](config.json) file defines the folder names and structure for the data and derivatives folders.
57 | Adjust `dataset_folder` in the [config.json](config.json) file from `null` to the absolute path of the folder containing all data.
58 |
59 |
60 | OK, you should be all set up now!
61 |
62 | Preprocessing code
63 | ==================
64 |
65 | This repository uses the [brain_pipe package](https://github.com/exporl/brain_pipe)
66 | to preprocess the data. It is installed automatically when installing the [requirements.txt](requirements.txt).
67 | You are invited to contribute to the [brain_pipe](https://github.com/exporl/brain_pipe) package if you want to add new preprocessing steps.
68 | Documentation for the brain_pipe package can be found [here](https://exporl.github.io/brain_pipe/).
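As a reference for the commands below, a filled-in [config.json](config.json) might look like this sketch (the dataset path is illustrative; the other values are the defaults shipped with this repository):

```json
{
  "dataset_folder": "/path/to/SparrKULee",
  "derivatives_folder": "derivatives",
  "preprocessed_eeg_folder": "preprocessed_eeg",
  "preprocessed_stimuli_folder": "preprocessed_stimuli",
  "split_folder": "split_data"
}
```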
69 |
70 | Example usage
71 | -------------
72 |
73 | There are multiple ways to run the preprocessing pipeline, specified below.
74 |
75 | **Warning:** the script and the YAML file will create both Mel spectrograms and envelope representations of the stimuli.
76 | If this is not desired, you can comment out the appropriate lines.
77 |
78 | **Make sure your [brain_pipe](https://github.com/exporl/brain_pipe) version is up to date (>= 0.0.3)!**
79 | You can ensure this by running `pip3 install --upgrade brain_pipe` or `pip3 install --upgrade -r requirements.txt`.
80 |
81 | ### 1. Use the python script [preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)
82 |
83 | ```bash
84 | python3 preprocessing_code/sparrKULee.py
85 | ```
86 |
87 | Different options (such as the number of parallel processes) can be specified from the command line.
88 | For more information, run:
89 |
90 | ```bash
91 | python3 preprocessing_code/sparrKULee.py --help
92 | ```
93 |
94 | ### 2. Use the YAML file with the [brain_pipe](https://github.com/exporl/brain_pipe) CLI
95 |
96 | For this option, you will have to fill in the `--dataset_folder`, `--derivatives_folder`,
97 | `--preprocessed_stimuli_dir` and `--preprocessed_eeg_dir` with the values from the [config.json](config.json) file.
98 |
99 | ```bash
100 | brain_pipe preprocessing_code/sparrKULee.yaml --dataset_folder {/path/to/dataset} --derivatives_folder {derivatives_folder} --preprocessed_stimuli_dir {preprocessed_stimuli_dir} --preprocessed_eeg_dir {preprocessed_eeg_dir}
101 | ```
102 |
103 | Optionally, you could read the [config.json](config.json) file directly from the command line:
104 |
105 | ```bash
106 | brain_pipe preprocessing_code/sparrKULee.yaml $(python3 -c "import json; f=open('config.json'); d=json.load(f); f.close(); print(' '.join([f'--{x}={y}' for x,y in d.items() if 'split_folder' != x]))")
107 | ```
108 |
109 | For more information about the [brain_pipe](https://github.com/exporl/brain_pipe) CLI,
110 | see the appropriate documentation for the [CLI](https://exporl.github.io/brain_pipe/cli.html) and [configuration files (e.g. YAML)](https://exporl.github.io/brain_pipe/configuration.html).
111 |
112 | Technical validation
113 | ====================
114 | This repository contains code to validate the preprocessed dataset using established models.
115 | Running this code will yield the results summarized in the paper. (LINK TO PAPER)
116 |
117 | Prerequisites
118 | -------------
119 | The technical validation code assumes that the preprocessing pipeline has been run and that the derivatives folder is available.
120 | The derivatives folder contains the preprocessed data and the necessary files to run the technical validation code.
121 | Either download the derivatives folder directly from the online dataset or run the preprocessing pipeline yourself ([preprocessing_code/sparrKULee.py](preprocessing_code/sparrKULee.py)).
122 |
123 | Example usage
124 | -------------
125 |
126 | We have defined some ready-to-go experiments to replicate the results summarized in the dataset paper.
127 | All these experiments use split (into training/validation/test partitions) and normalised data, which can be obtained by
128 | running [technical_validation/util/split_and_normalize.py](technical_validation/util/split_and_normalize.py).
129 |
130 | The experiment files live in the [technical_validation/experiments](technical_validation/experiments) folder. The training log,
131 | The training log, best model and evaluation results will be stored in a folder called
132 | `results_{experiment_name}`.
133 | 
134 | To replicate the results summarized in the dataset paper, run the following experiments:
135 | ```bash
136 | # train the dilated convolutional model introduced by Accou et al. (https://doi.org/10.1088/1741-2552/ac33e9)
137 | python3 technical_validation/experiments/match_mismatch_dilated_convolutional_model.py
138 | 
139 | # train a simple linear backward model, reconstructing the envelope from EEG
140 | # using filtered data in different frequency bands:
141 | # a simple linear baseline model with Pearson correlation as a loss function, similar to the baseline model used in Accou et al. (2022) (https://www.biorxiv.org/content/10.1101/2022.09.28.509945).
142 | 
143 | python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 0.5 --lowpass 30
144 | python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 0.5 --lowpass 4
145 | python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 4 --lowpass 8
146 | python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 8 --lowpass 14
147 | python3 technical_validation/experiments/regression_linear_backward_model.py --highpass 14 --lowpass 30
148 | 
149 | # train a simple linear forward model, predicting the EEG response from the envelope,
150 | # using filtered data in different frequency bands
151 | python3 technical_validation/experiments/regression_linear_forward_model.py --highpass 0.5 --lowpass 30
152 | python3 technical_validation/experiments/regression_linear_forward_model.py --highpass 0.5 --lowpass 4
153 | 
154 | # train/evaluate the VLAAI model as proposed by Accou et al. (2022) (https://www.biorxiv.org/content/10.1101/2022.09.28.509945). You can find a pre-trained model at VLAAI's github page (https://github.com/exporl/vlaai).
155 | python3 technical_validation/experiments/regression_vlaai.py
156 | ```
157 | 
158 | Finally, you can generate the plots shown in the dataset paper by running the [technical_validation/util/plot_results.py](technical_validation/util/plot_results.py) script.
159 | 
160 | 
--------------------------------------------------------------------------------
/download_code/__init__.py:
--------------------------------------------------------------------------------
1 | """Code to parse and download Dataverse datasets."""
2 | import datetime
3 | import hashlib
4 | import json
5 | import multiprocessing as mp
6 | import os
7 | import urllib.request
8 | 
9 | 
10 | 
11 | class DataverseDownloader:
12 |     """Download files from a Dataverse dataset."""
13 | 
14 |     def __init__(
15 |         self,
16 |         download_path,
17 |         server,
18 |         overwrite=False,
19 |         check_md5=True,
20 |         multiprocessing=-1,
21 |         datetime_format="%Y-%m-%d %H:%M:%S",
22 |     ):
23 |         """Create a new DataverseDownloader.
24 | 
25 |         Parameters
26 |         ----------
27 |         download_path: str
28 |             The path to download the files to. The path will be created if it does not
29 |             exist.
30 |         server: str
31 |             The hostname of the server to download the files from.
32 |         overwrite: bool
33 |             Whether to overwrite existing files.
34 |         check_md5: bool
35 |             Whether to check the MD5 checksum of the downloaded files.
36 |         multiprocessing: int
37 |             The number of cores to use for multiprocessing. Set to 0 or 1 to disable
38 |             multiprocessing. The default -1 uses all available cores.
39 |         datetime_format: str
40 |             The datetime format to use for printing the start and end time of the
41 |             download.
42 |         """
43 |         self.download_path = download_path
44 |         self.overwrite = overwrite
45 |         self.check_md5 = check_md5
46 |         self.multiprocessing = multiprocessing
47 |         self.server = server
48 |         self.datetime_format = datetime_format
49 |         self._total = 0
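    # The two Dataverse API endpoints this module talks to:
    #   - file download (see get_url below):
    #       https://{server}/api/access/datafile/{file_id}
    #   - dataset metadata (see DataverseParser.get_url further down):
    #       https://{server}/api/datasets/:persistentId/?persistentId={doi}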
50 | 
51 |     def get_url(self, file_id):
52 |         """Get the download URL for a file ID.
53 | 
54 |         Parameters
55 |         ----------
56 |         file_id: str
57 |             The file ID to get the download URL for.
58 | 
59 |         Returns
60 |         -------
61 |         str
62 |             The download URL for the file ID.
63 |         """
64 |         return f"https://{self.server}/api/access/datafile/{file_id}?gbrecs=true"
65 | 
66 |     def __call__(self, file_id_mapping, filter_fn=lambda x, y: True):
67 |         """Download the files from a file ID mapping.
68 | 
69 |         Parameters
70 |         ----------
71 |         file_id_mapping: Mapping[str, Mapping[str, Any]]
72 |             A mapping from the path (relative to self.download_path) to save the file
73 |             and another mapping containing at least 'id' as a key and 'md5' as a key
74 |             (only necessary if self.check_md5 is True).
75 |         filter_fn: Callable[[str, str], bool]
76 |             A function that takes the path and file ID and returns whether to download
77 |             the file.
78 | 
79 |         Returns
80 |         -------
81 |         List[str]
82 |             A list of the downloaded files.
83 |         """
84 |         # Get the appropriate map function
85 |         if self.multiprocessing not in [0, 1]:
86 |             pool_count = (
87 |                 self.multiprocessing if self.multiprocessing > 0 else os.cpu_count()
88 |             )
89 |             pool = mp.Pool(pool_count)
90 |             map_fn = pool.map
91 |         else:
92 |             map_fn = map
93 | 
94 |         # Filter the file ID mapping
95 |         filtered_path_dict = self.filter(file_id_mapping, filter_fn=filter_fn)
96 | 
97 |         # Set total for logging
98 |         self._total = len(filtered_path_dict)
99 | 
100 |         # Download the files
101 |         print(
102 |             f"Started downloading at "
103 |             f"{datetime.datetime.now().strftime(self.datetime_format)}"
104 |         )
105 |         output = list(map_fn(self.download, enumerate(filtered_path_dict.items())))
106 |         print(
107 |             f"Finished downloading at "
108 |             f"{datetime.datetime.now().strftime(self.datetime_format)}"
109 |         )
110 | 
111 |         # Clean up multiprocessing
112 |         if self.multiprocessing not in [0, 1]:  # a pool only exists in this case
113 |             pool.close()
114 |             pool.join()
115 |         return output
116 | 
117 |     def filter(self, file_id_mapping, filter_fn=lambda x, y: True):
118 |         """Filter a file ID mapping.
119 | 
120 |         Parameters
121 |         ----------
122 |         file_id_mapping: Mapping[str, Mapping[str, Any]]
123 |             A mapping from the path (relative to self.download_path) to save the file
124 |             and another mapping containing at least 'id' as a key and 'md5' as a key
125 |             (only necessary if self.check_md5 is True).
126 |         filter_fn: Callable[[str, str], bool]
127 |             A function that takes the path and file ID and returns whether to download
128 |             the file.
129 | 
130 |         Returns
131 |         -------
132 |         Mapping[str, str]
133 |             The filtered file ID mapping.
134 |         """
135 |         return {k: v for k, v in file_id_mapping.items() if filter_fn(k, v)}
136 | 
137 |     def compare_checksum(self, filepath, checksum):  # True if the file's MD5 digest matches the expected checksum
138 |         with open(filepath, 'rb') as fp:
139 |             return hashlib.md5(fp.read()).hexdigest() == checksum
140 | 
141 |     def download(self, data):
142 |         """Download a file.
143 | 
144 |         Parameters
145 |         ----------
146 |         data: Tuple[int, Tuple[str, str]]
147 |             The index of the file and a tuple containing the relative path of the file,
148 |             and a mapping containing at least a key 'id' for the file ID and a key
149 |             'md5' for the MD5 checksum (only necessary if self.check_md5 is True).
150 | 
151 |         Returns
152 |         -------
153 |         str
154 |             The path to the downloaded file.
155 |         """
156 |         index, (path, mapping) = data
157 |         filepath = os.path.join(self.download_path, path)
158 |         print_preamble = (
159 |             f"({index+1}/{self._total}) | "
160 |             if self.multiprocessing not in [0, 1]
161 |             else ""
162 |         )
163 | 
164 |         if os.path.exists(filepath) and not self.overwrite:
165 |             if self.check_md5:
166 |                 checksum_comparison = self.compare_checksum(filepath, mapping['md5'])
167 |                 if checksum_comparison:
168 |                     print(
169 |                         f"{print_preamble}{filepath} already exists and has the same "
170 |                         f"checksum, skipping (set overwrite=True to overwrite)"
171 |                     )
172 |                     return filepath
173 |                 else:
174 |                     print(
175 |                         f"{print_preamble}{filepath} already exists but has a"
176 |                         f" different checksum, overwriting"
177 |                     )
178 |             else:
179 |                 print(
180 |                     f"{print_preamble}{filepath} already exists, "
181 |                     f"skipping (set overwrite=True to overwrite)"
182 |                 )
183 |                 return filepath
184 |         os.makedirs(os.path.dirname(filepath), exist_ok=True)
185 |         print(f"{print_preamble}Downloading {path} to {filepath}...", end=" ")
186 |         urllib.request.urlretrieve(self.get_url(mapping['id']), filename=filepath)
187 | 
188 |         extra_msg = ""
189 |         if self.check_md5:
190 |             checksum_comparison = self.compare_checksum(filepath, mapping['md5'])
191 |             if not checksum_comparison:
192 |                 raise ValueError(
193 |                     f"Checksum of {filepath} does not match the expected checksum."
194 |                 )
195 |             else:
196 |                 extra_msg = " (checksum matches)"
197 |         print(f"Done{extra_msg}")
198 |         return filepath
199 | 
200 | 
201 | class DataverseParser:
202 |     """Parse a Dataverse dataset for file IDs."""
203 | 
204 |     def __init__(self, server):
205 |         """Create a new DataverseParser.
206 | 
207 |         Parameters
208 |         ----------
209 |         server: str
210 |             The hostname of the server to download the files from.
211 |         """
212 |         self.server = server
213 | 
214 |     def get_url(self, dataset_id):
215 |         """Get the URL to get the dataset information from.
216 | 
217 |         Parameters
218 |         ----------
219 |         dataset_id: str
220 |             The DOI of the requested dataset.
221 | 
222 |         Returns
223 |         -------
224 |         str
225 |             The URL to get the dataset information from.
226 |         """
227 |         return f"https://{self.server}/api/datasets/:persistentId/" \
228 |                f"?persistentId={dataset_id}"
229 | 
230 |     def __call__(self, dataset_id):
231 |         """Create a mapping between the relative path to the file and the file ID.
232 | 
233 |         Parameters
234 |         ----------
235 |         dataset_id: str
236 |             The DOI of the requested dataset.
237 | 
238 |         Returns
239 |         -------
240 |         Mapping[str, str]
241 |             A mapping between the relative path to the file in the dataset and the
242 |             file ID.
243 |         """
244 |         # Get the dataset information
245 |         url = self.get_url(dataset_id)
246 |         print(f"Loading data from {url}")
247 |         raw_info = urllib.request.urlopen(url).read().decode("utf-8")
248 |         info = json.loads(raw_info)
249 |         version_info = info["data"]["latestVersion"]
250 |         print(
251 |             f'Parsing data for version: '
252 |             f'{version_info["versionNumber"]}.{version_info["versionMinorNumber"]}.'
253 |         )
254 |         # Parse the files
255 |         file_id_mapping = {}
256 |         for file_info in version_info["files"]:
257 |             path = os.path.join(
258 |                 file_info.get("directoryLabel", ""), file_info["dataFile"]["filename"]
259 |             )
260 |             file_id_mapping[path] = {
261 |                 'id': file_info["dataFile"]["id"],
262 |                 'md5': file_info["dataFile"]["md5"],
263 |             }
264 |         return file_id_mapping
265 | 
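# Roughly how download_script.py wires these classes together (illustrative
# sketch; the download directory is a placeholder):
#
#     parser = DataverseParser("rdr.kuleuven.be")
#     file_id_mapping = parser("doi:10.48804/K3VSND")
#     downloader = DataverseDownloader("/path/to/download/dir", "rdr.kuleuven.be")
#     downloaded_files = downloader(file_id_mapping)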
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_linear_backward_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for a linear baseline method."""
2 | 
3 | 
4 | import sys
5 | import argparse
6 | 
7 | 
8 | import numpy as np
9 | import glob
10 | import json
11 | import logging
12 | import os
13 | import scipy.stats
14 | 
15 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
16 | 
17 | 
18 | 
19 | def time_lag_matrix(input_, tmin, tmax):
20 |     """Create a time-lag matrix from a 2D numpy array.
21 | 
22 |     Parameters
23 |     ----------
24 |     input_: np.ndarray
25 |         2D numpy array with shape (n_samples, n_channels)
26 |     tmin, tmax: int
27 |         First (inclusive) and last (exclusive) time lag, in samples.
28 | 
29 |     Returns
30 |     -------
31 |     np.ndarray
32 |         2D numpy array with shape (n_samples_trimmed, n_channels * (tmax - tmin))
33 |     """
34 |     # Create a time-lag matrix
35 |     numChannels = input_.shape[1]
36 | 
37 |     final_array = np.zeros((input_.shape[0], numChannels * (tmax - tmin)))
38 | 
39 |     for index, shift in enumerate(range(tmin, tmax)):
40 |         # roll the array to the right
41 |         shifted_data = np.roll(input_, -shift, axis=0)
42 |         final_array[:, index * numChannels: (index + 1) * numChannels] = shifted_data
43 | 
44 |     if tmin < 0:
45 |         return final_array[np.abs(tmin):-tmax+1, :]
46 |     else:
47 |         return final_array[:-tmax+1, :]
48 | 
49 | 
50 | def train_model_cov(cxx, cxy, ridge_param):
51 |     return np.linalg.solve(cxx + ridge_param * np.eye(cxx.shape[0]), cxy)
52 | 
53 | 
54 | def evaluate_model(model, test_eeg, test_env):
55 |     pred_env = np.matmul(test_eeg, model)
56 |     return scipy.stats.pearsonr(pred_env[:, 0], test_env[:, 0])[0]
57 | 
58 | def permutation_test(model, eeg, env, tmin, tmax, numPermutations=100):
59 |     pred_env = np.matmul(eeg, model)
60 |     corrs = []
61 |     for permutation_index in range(numPermutations):
62 |         print(f'Permutation {permutation_index+1:03d}\r', end='')
63 |         random_shift = np.random.randint(tmax-tmin, env.shape[0] - (tmax-tmin))
64 |         temp_env = np.roll(env, random_shift, axis=0)
65 |         # temp_eeg = temp_eeg[random_shift:, :]
66 |         # temp_pred_eeg = pred_eeg[:-random_shift, :]
67 | 
68 |         corrs.append(scipy.stats.pearsonr(temp_env[:, 0], pred_env[:, 0])[0])
69 | 
70 |     print()
71 |     return np.array(corrs)
72 | 
73 | 
74 | def crossval_over_recordings(all_data, tmin, tmax, ridge_param):
75 |     # Cross validation loop to determine the optimal ridge parameter
76 |     fold_scores = []
77 |     for fold in range(len(all_data)):
78 |         print(f'fold {fold}...')
79 | 
80 |         # train_folds
81 |         train_eeg_folds = [x[0] for i, x in enumerate(all_data) if i != fold]
82 |         train_env_folds = [x[1] for i, x in enumerate(all_data) if i != fold]
83 | 
84 |         # test_fold
85 |         test_eeg_fold = [x[0] for i, x in enumerate(all_data) if i == fold][0]
86 |         test_env_fold = [x[1] for i, x in enumerate(all_data) if i == fold][0]
87 | 
88 |         # create the model
89 |         # closed-form solution,
90 |         train_eegs = [time_lag_matrix(eeg, tmin, tmax) for eeg in train_eeg_folds]
91 |         cxx = np.sum([np.matmul(x.T, x) for x in train_eegs], axis=0)
92 |         cxy = np.sum([np.matmul(x.T, y[:-(tmax - tmin) + 1, :]) for x, y in zip(train_eegs, train_env_folds)], axis=0)
93 | 
94 |         if not isinstance(ridge_param, (int, float)):
95 |             ridge_scores = []
96 |             for lambd in ridge_param:
97 |                 model = train_model_cov(cxx, cxy, lambd)
98 |                 # evaluate the model on the test set
99 |                 score = evaluate_model(model, time_lag_matrix(test_eeg_fold, tmin, tmax), test_env_fold[:-(tmax - tmin) + 1, :])
100 |                 ridge_scores.append(score)
101 |             fold_scores.append(ridge_scores)
102 |         else:
103 |             model = train_model_cov(cxx, cxy, ridge_param)
104 |             score = evaluate_model(model, time_lag_matrix(test_eeg_fold, tmin, tmax), test_env_fold[:-(tmax - tmin) + 1, :])
105 |             fold_scores.append(score)
106 |     return fold_scores
107 | 
108 | def training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param):
109 |     print(f"Training model for subject {subject}")
110 | 
111 |     train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
112 |     train_files = [x for x in train_files if 'audiobook_15_' not in x]
113 |     train_generator = RegressionDataGenerator(train_files, high_pass_freq=highpass, low_pass_freq=lowpass)
114 |     all_data = [x for x in train_generator]
115 | 
116 |     # Leave-one-out cross validation based on number of recordings
117 |     # Done to determine the optimal ridge parameter
118 |     numFolds = len(all_data)
119 | 
120 |     fold_scores = crossval_over_recordings(all_data, tmin, tmax, ridge_param)
121 | 
122 |     if not isinstance(ridge_param, (int, float)):
123 |         fold_scores = np.array(fold_scores)
124 |         # Average across folds to obtain one correlation value
125 |         # per lambda value
126 |         fold_scores = fold_scores.mean(axis=0)
127 |         best_ridge = ridge_param[np.argmax(fold_scores)]
128 |         print(f"Best lambda: {best_ridge}")
129 |     else:
130 |         best_ridge = ridge_param
131 | 
132 | 
133 |     # Actual training of the model on all training folds
134 |     train_eegs = [time_lag_matrix(x[0].numpy(), tmin, tmax) for x in all_data]
135 |     train_envs = [x[1].numpy() for x in all_data]
136 |     cxx = np.sum([np.matmul(x.T, x) for x in train_eegs], axis=0)
137 |     cxy = np.sum([np.matmul(x.T, y[:-(tmax-tmin)+1, :]) for x, y in zip(train_eegs, train_envs)], axis=0)
138 |     model = train_model_cov(cxx, cxy, best_ridge)
139 | 
140 |     # evaluate the model on the test set
141 |     test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x]
142 |     test_files = [x for x in test_files if 'audiobook_15_' not in x]
143 |     test_generator = RegressionDataGenerator(test_files, high_pass_freq=highpass, low_pass_freq=lowpass, return_filenames=True)
144 | 
145 |     # calculate pearson correlation, per test segment
146 |     # and average over all test segments
147 |     test_info = {'subject': subject, 'stim_filename':[], 'score': [], 'null_distr': [], 'ridge_param': best_ridge, 'model_weights': model.tolist(), 'highpass':highpass, 'lowpass':lowpass, 'numFolds':numFolds}
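    # The loop below scores each held-out segment with the trained decoder.
    # permutation_test (defined above) builds a null distribution by circularly
    # shifting the envelope with np.roll and re-computing the correlation; the
    # true scores are later compared against the 95th percentile of this
    # distribution (stored below as test_info['95_percentile']).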
148 |     print(f"Testing model on test data... ({numFolds})")
149 |     for test_filenames, test_seg in test_generator:
150 |         test_eeg = test_seg[0]
151 |         test_env = test_seg[1]
152 |         stim_filename = os.path.basename(test_filenames[1]).split("_-_")[-2]
153 | 
154 |         test_eeg = time_lag_matrix(test_eeg, tmin, tmax)
155 |         # shorten env to match the length of the time-lagged eeg
156 |         test_env = test_env[:-(tmax-tmin)+1, :]
157 |         # predict
158 |         pearson_scores = evaluate_model(model, test_eeg, test_env)
159 | 
160 |         test_info['stim_filename'].append(stim_filename)
161 |         test_info['score'].append(pearson_scores)
162 |         # null distribution
163 |         null_distr = permutation_test(model, test_eeg, test_env, tmin, tmax).tolist()
164 |         test_info['null_distr'].append(null_distr)
165 |     test_info['mean_score'] = np.mean(test_info['score'])
166 |     test_info['95_percentile'] = np.percentile(test_info['null_distr'], 95)
167 | 
168 |     return test_info
169 | 
170 | 
171 | 
172 | if __name__ == "__main__":
173 |     # Parameters
174 | 
175 |     # frequency band chosen for the experiment
176 |     # delta (0.5 - 4)
177 |     # theta (4 - 8)
178 |     # alpha (8 - 14)
179 |     # beta (14 - 30)
180 |     # broadband (0.5 - 32)
181 |     parser = argparse.ArgumentParser()
182 |     parser.add_argument('--highpass', type=float, default=None)
183 |     parser.add_argument('--lowpass', type=float, default=4)
184 | 
185 |     args = parser.parse_args()
186 |     highpass = args.highpass
187 |     lowpass = args.lowpass
188 | 
189 |     for highpass, lowpass in [(None, 4), (4, 8), (8, 14), (14, 30), (None, None)]:  # note: this sweep overrides the --highpass/--lowpass CLI values parsed above
190 |         numChannels = 64
191 |         tmin = -np.round(0.1*64).astype(int)  # -100 ms
192 |         tmax = np.round(0.4*64).astype(int)  # 400 ms
193 |         ridge_param = [10**x for x in range(-6, 7, 2)]
194 |         overwrite = False
195 | 
196 |         results_filename = 'eval_filter_{subject}_{tmin}_{tmax}_{highpass}_{lowpass}.json'
197 | 
198 |         # Get the path to the config file
199 |         experiments_folder = os.path.dirname(__file__)
200 |         main_folder = os.path.dirname(os.path.dirname(experiments_folder))
201 |         config_path = os.path.join(main_folder, 'config.json')
202 | 
203 |         # Load the config
204 |         with open(config_path) as fp:
205 |             config = json.load(fp)
206 | 
207 |         # Provide the path of the dataset
208 |         # which is split already to train, val, test
209 | 
210 |         data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
211 |         features = ["envelope", "eeg"]
212 | 
213 |         # Create a directory to store (intermediate) results
214 |         results_folder = os.path.join(experiments_folder, "results_linear_backward")
215 |         os.makedirs(results_folder, exist_ok=True)
216 | 
217 | 
218 |         # get all the subjects
219 |         all_files = glob.glob(os.path.join(data_folder, "train_-_*"))
220 |         subjects = list(set([os.path.basename(x).split("_-_")[1] for x in all_files]))
221 | 
222 |         evaluation_all_subs = {}
223 |         chance_level_all_subs = {}
224 | 
225 |         # train one model per subject
226 |         for subject in subjects:
227 |             save_path = os.path.join(results_folder, results_filename.format(subject=subject, tmin=tmin, tmax=tmax, highpass=highpass, lowpass=lowpass))
228 |             if not os.path.exists(save_path) or overwrite:
229 |                 result = training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param)
230 | 
231 |                 # save the results
232 |                 with open(save_path, 'w') as fp:
233 |                     json.dump(result, fp)
234 |             else:
235 |                 print(f"Results for {subject} already exist, skipping...")
236 | 
237 | 
238 | 
239 | 
240 | 
241 | 
242 | 
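# Mathematical background for train_model_cov above: with X the time-lagged EEG
# matrix and y the envelope, it solves the regularised least-squares problem
#     w = argmin_w ||X w - y||^2 + ridge_param * ||w||^2
# via its closed-form solution
#     w = (X^T X + ridge_param * I)^(-1) X^T y,
# where cxx and cxy accumulate X^T X and X^T y over recordings.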
--------------------------------------------------------------------------------
/technical_validation/experiments/regression_linear_forward_model.py:
--------------------------------------------------------------------------------
1 | """Example experiment for a linear baseline method."""
2 | 
3 | 
4 | import sys
5 | import argparse
6 | 
7 | 
8 | import numpy as np
9 | import glob
10 | import json
11 | import logging
12 | import os
13 | import scipy.stats
14 | 
15 | from technical_validation.util.dataset_generator import RegressionDataGenerator, create_tf_dataset
16 | 
17 | 
18 | 
19 | def time_lag_matrix(input_, tmin, tmax):
20 |     """Create a time-lag matrix from a 2D numpy array.
21 | 
22 |     Parameters
23 |     ----------
24 |     input_: np.ndarray
25 |         2D numpy array with shape (n_samples, n_channels)
26 |     tmin, tmax: int
27 |         First (inclusive) and last (exclusive) time lag, in samples.
28 | 
29 |     Returns
30 |     -------
31 |     np.ndarray
32 |         2D numpy array with shape (n_samples_trimmed, n_channels * (tmax - tmin))
33 |     """
34 |     # Create a time-lag matrix
35 |     numChannels = input_.shape[1]
36 | 
37 |     final_array = np.zeros((input_.shape[0], numChannels * (tmax - tmin)))
38 | 
39 |     for index, shift in enumerate(range(tmin, tmax)):
40 |         # roll the array to the right
41 |         shifted_data = np.roll(input_, -shift, axis=0)
42 |         final_array[:, index * numChannels: (index + 1) * numChannels] = shifted_data
43 | 
44 |     if tmin < 0:
45 |         return final_array[np.abs(tmin):-tmax+1, :]
46 |     else:
47 |         return final_array[:-tmax+1, :]
48 | 
49 | 
50 | def train_model_cov(cxx, cxy, ridge_param):
51 |     return np.linalg.solve(cxx + ridge_param * np.eye(cxx.shape[0]), cxy)
52 | 
53 | def evaluate_model(model, test_env, test_eeg):
54 |     pred_eeg = np.matmul(test_env, model)
55 |     channel_scores = []
56 |     for channel in range(test_eeg.shape[1]):
57 |         score = scipy.stats.pearsonr(pred_eeg[:, channel], test_eeg[:, channel])[0]
58 |         channel_scores.append(score.tolist())
59 |     return channel_scores
60 | 
61 | def permutation_test(model, eeg, env, tmin, tmax, numPermutations=100):
62 |     pred_eeg = np.matmul(env, model)
63 |     corrs = []
64 |     for permutation_index in range(numPermutations):
65 |         print(f'Permutation {permutation_index+1:03d}\r', end='')
66 |         random_shift = np.random.randint(tmax-tmin, eeg.shape[0] - (tmax-tmin))
67 |         temp_eeg = np.roll(eeg, random_shift, axis=0)
68 |         # temp_eeg = temp_eeg[random_shift:, :]
69 |         # temp_pred_eeg = pred_eeg[:-random_shift, :]
70 |         temp_pred_eeg = pred_eeg
71 |         channel_corrs = []
72 |         for channel in range(eeg.shape[1]):
73 |             channel_corrs.append(scipy.stats.pearsonr(temp_eeg[:, channel], temp_pred_eeg[:, channel])[0])
74 |         corrs.append(channel_corrs)
75 |     print()
76 |     return np.array(corrs)
77 | 
78 | def crossval_over_recordings(all_data, tmin, tmax, ridge_param):
79 |     # Cross validation loop to determine the optimal ridge parameter
80 |     fold_scores = []
81 |     for fold in range(len(all_data)):
82 |         print(f'fold {fold}...')
83 | 
84 |         # train_folds
85 |         train_eeg_folds = [x[0] for i, x in enumerate(all_data) if i != fold]
86 |         train_env_folds = [x[1] for i, x in enumerate(all_data) if i != fold]
87 | 
88 |         # test_fold
89 |         test_eeg_fold = [x[0] for i, x in enumerate(all_data) if i == fold][0]
90 |         test_env_fold = [x[1] for i, x in enumerate(all_data) if i == fold][0]
91 | 
92 |         # create the model
93 |         # closed-form solution,
94 |         train_envs = [time_lag_matrix(env, tmin, tmax) for env in train_env_folds]
95 |         cxx = np.sum([np.matmul(x.T, x) for x in train_envs], axis=0)
96 |         cxy = np.sum([np.matmul(x.T, y[:-(tmax - tmin) + 1, :]) for x, y in
97 |                       zip(train_envs, train_eeg_folds)], axis=0)
98 | 
99 |         if not isinstance(ridge_param, (int, float)):
100 | ridge_scores = [] 101 | for lambd in ridge_param: 102 | model = train_model_cov(cxx, cxy, lambd) 103 | # evaluate the model on the test set 104 | score = evaluate_model(model, 105 | time_lag_matrix(test_env_fold, tmin, tmax), 106 | test_eeg_fold[:-(tmax - tmin) + 1, :]) 107 | ridge_scores.append(score) 108 | fold_scores.append(ridge_scores) 109 | else: 110 | model = train_model_cov(cxx, cxy, ridge_param) 111 | score = evaluate_model(model, time_lag_matrix(test_env_fold, tmin, tmax), 112 | test_eeg_fold[:-(tmax - tmin) + 1, :]) 113 | fold_scores.append(score) 114 | return fold_scores 115 | 116 | def training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param): 117 | print(f"Training model for subject {subject}") 118 | 119 | train_files = [x for x in glob.glob(os.path.join(data_folder, "train_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x] 120 | train_files = [x for x in train_files if 'audiobook_15_' not in x] 121 | train_generator = RegressionDataGenerator(train_files, high_pass_freq=highpass, low_pass_freq=lowpass) 122 | all_data = [x for x in train_generator] 123 | 124 | # Leave-one-out cross validation based on number of recordings 125 | # Done to determine the optimal ridge parameter 126 | numFolds = len(all_data) 127 | 128 | fold_scores = crossval_over_recordings(all_data, tmin, tmax, ridge_param) 129 | 130 | if not isinstance(ridge_param, (int, float)): 131 | fold_scores = np.array(fold_scores) 132 | # Take the average across channels and folds to obtain 1 correlation value 133 | # per lambda value 134 | fold_scores = fold_scores.mean(axis=2).mean(axis=0) 135 | best_ridge = ridge_param[np.argmax(fold_scores)] 136 | print(f"Best lambda: {best_ridge}") 137 | else: 138 | best_ridge = ridge_param 139 | 140 | 141 | # Actual training of the model on all training folds 142 | train_eegs = [x[0] for x in all_data] 143 | train_envs = [time_lag_matrix(x[1], tmin, tmax) for x in all_data] 144 | cxx = np.sum([np.matmul(x.T, x) for x in train_envs], axis=0) 145 | cxy = np.sum([np.matmul(x.T, y[:-(tmax-tmin)+1, :]) for x, y in zip(train_envs, train_eegs)], axis=0) 146 | model = train_model_cov(cxx, cxy, best_ridge) 147 | 148 | # # evaluate the model on the test set 149 | test_files = [x for x in glob.glob(os.path.join(data_folder, "test_-_*")) if os.path.basename(x).split("_-_")[-1].split(".")[0] in features and subject in x] 150 | test_files = [x for x in test_files if 'audiobook_15_' not in x] 151 | test_generator = RegressionDataGenerator(test_files, high_pass_freq=highpass, low_pass_freq=lowpass, return_filenames=True) 152 | 153 | # calculate pearson correlation, per test segment 154 | # and average over all test segments 155 | test_info = {'subject': subject, 'stim_filename':[], 'score': [], 'null_distr': [], 'ridge_param': best_ridge, 'model_weights': model.tolist(), 'highpass':highpass, 'lowpass':lowpass, 'numFolds':numFolds} 156 | print(f"Testing model on test data... 
({numFolds})")
157 |     for test_filenames, test_seg in test_generator:
158 |         test_eeg = test_seg[0]
159 |         test_env = test_seg[1]
160 |         stim_filename = os.path.basename(test_filenames[1]).split("_-_")[-2]
161 | 
162 |         test_env = time_lag_matrix(test_env, tmin, tmax)
163 |         # shorten eeg to match the length of the time-lagged envelope
164 |         test_eeg = test_eeg[:-(tmax-tmin)+1, :]
165 |         # predict
166 |         pearson_scores = evaluate_model(model, test_env, test_eeg)
167 | 
168 |         test_info['stim_filename'].append(stim_filename)
169 |         test_info['score'].append(pearson_scores)
170 |         # null distribution
171 |         null_distr = permutation_test(model, test_eeg, test_env, tmin, tmax).tolist()
172 |         test_info['null_distr'].append(null_distr)
173 |     test_info['mean_score_per_channel'] = np.mean(test_info['score'], axis=0).tolist()
174 |     test_info['mean_score'] = np.mean(test_info['score'])
175 |     null_distr = np.reshape(test_info['null_distr'], (-1, 64))
176 |     test_info['95_percentile_per_channel'] = np.percentile(null_distr, 95, axis=0).tolist()
177 |     test_info['95_percentile'] = np.percentile(test_info['null_distr'], 95)
178 | 
179 |     return test_info
180 | 
181 | 
182 | 
183 | if __name__ == "__main__":
184 |     # Parameters
185 | 
186 |     # frequency band chosen for the experiment
187 |     # delta (0.5 - 4)
188 |     # theta (4 - 8)
189 |     # alpha (8 - 14)
190 |     # beta (14 - 30)
191 |     # broadband (0.5 - 32)
192 |     parser = argparse.ArgumentParser()
193 |     parser.add_argument('--highpass', type=float, default=None)
194 |     parser.add_argument('--lowpass', type=float, default=4)
195 | 
196 |     args = parser.parse_args()
197 |     highpass = args.highpass
198 |     lowpass = args.lowpass
199 | 
200 |     for highpass, lowpass in [(None, 4), (4, 8), (8, 14), (14, 30), (None, None)]:  # note: this sweep overrides the --highpass/--lowpass CLI values parsed above
201 | 
202 | 
203 |         numChannels = 64
204 |         tmin = -np.round(0.1*64).astype(int)  # -100 ms
205 |         tmax = np.round(0.4*64).astype(int)  # 400 ms
206 |         ridge_param = [10**x for x in range(-6, 7, 2)]
207 |         overwrite = False
208 | 
209 |         results_filename = 'eval_filter_{subject}_{tmin}_{tmax}_{highpass}_{lowpass}.json'
210 | 
211 |         # Get the path to the config file
212 |         experiments_folder = os.path.dirname(__file__)
213 |         main_folder = os.path.dirname(os.path.dirname(experiments_folder))
214 |         config_path = os.path.join(main_folder, 'config.json')
215 | 
216 |         # Load the config
217 |         with open(config_path) as fp:
218 |             config = json.load(fp)
219 | 
220 |         # Provide the path of the dataset
221 |         # which is split already to train, val, test
222 | 
223 |         data_folder = os.path.join(config["dataset_folder"], config["derivatives_folder"], config["split_folder"])
224 |         features = ["envelope", "eeg"]
225 | 
226 |         # Create a directory to store (intermediate) results
227 |         results_folder = os.path.join(experiments_folder, "results_linear_forward")
228 |         os.makedirs(results_folder, exist_ok=True)
229 | 
230 | 
231 |         # get all the subjects
232 |         all_files = glob.glob(os.path.join(data_folder, "train_-_*"))
233 |         subjects = list(set([os.path.basename(x).split("_-_")[1] for x in all_files]))
234 | 
235 |         evaluation_all_subs = {}
236 |         chance_level_all_subs = {}
237 | 
238 |         # train one model per subject
239 |         for subject in subjects:
240 |             save_path = os.path.join(results_folder, results_filename.format(subject=subject, tmin=tmin, tmax=tmax, highpass=highpass, lowpass=lowpass))
241 |             if not os.path.exists(save_path) or overwrite:
242 |                 result = training_loop(subject, data_folder, features, highpass, lowpass, tmin, tmax, ridge_param)
243 | 
244 |                 # save the results
245 |                 with open(save_path, 'w') as fp:
246 |                     json.dump(result, fp)
247 |             else:
248 |                 print(f"Results for {subject} already exist, skipping...")
249 | 
250 | 
251 | 
252 | 
253 | 
254 | 
255 | 
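# Note: this forward model mirrors regression_linear_backward_model.py, but the
# envelope (not the EEG) is time-lagged and used to predict every EEG channel,
# so evaluate_model returns one Pearson correlation per channel and the results
# include per-channel means and per-channel 95th percentiles.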
--------------------------------------------------------------------------------
/technical_validation/util/plot_results.py:
--------------------------------------------------------------------------------
1 | # import seaborn as sns
2 | import glob
3 | import json
4 | import os
5 | 
6 | import matplotlib.pyplot as plt
7 | import mne
8 | import numpy as np
9 | import pandas as pd
10 | import scipy.stats
11 | import seaborn as sns
12 | 
13 | # generate plots from all the different results and save them in the figures folder
14 | 
15 | # load the results
16 | base_results_folder = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'experiments')
17 | os.makedirs(os.path.join(base_results_folder, 'figures'), exist_ok=True)
18 | plot_dilation = True
19 | plot_linear_backward = True
20 | plot_linear_forward = True
21 | plot_vlaai = True
22 | 
23 | freq_bands = {
24 |     'Delta [0.5-4]': (0.5, 4.0),
25 |     'Theta [4-8]': (4, 8.0),
26 |     'Alpha [8-14]': (8, 14.0),
27 |     'Beta [14-30]': (14, 30.0),
28 |     'Broadband [0.5-32]': (0.5, 32.0),
29 | }
30 | 
31 | if plot_dilation:
32 |     # dilation model, match mismatch results
33 |     # plot boxplot of the results per window length
34 |     # load evaluation results for all window lengths
35 | 
36 |     files = glob.glob(os.path.join(base_results_folder, "results_dilated_convolutional_model/eval_*.json"))
37 |     # sort the files
38 |     files.sort()
39 | 
40 |     # create dict to save all results per sub
41 |     results = []
42 |     windows = []
43 |     for f in files:
44 |         # load the results
45 |         with open(f, "rb") as ff:
46 |             res = json.load(ff)
47 |         # loop over res and get accuracy in a list
48 |         acc = []
49 |         for sub, sub_res in res.items():
50 | 
51 |             if 'acc' in sub_res:
52 |                 acc.append(sub_res['acc']*100)
53 | 
54 |         results.append(acc)
55 | 
56 |         # get the window length
57 |         windows.append(int(f.split("_")[-1].split(".")[0]))
58 | 
59 |     # sort windows and results according to windows
60 |     windows, results = zip(*sorted(zip(windows, results)))
61 |     # convert windows to seconds
62 |     windows = ['%d' % int(w/64) for w in windows]
63 | 
64 |     # boxplot of the results
65 |     plt.boxplot(results, labels=windows)
66 |     plt.xlabel("Window length (s)")
67 |     plt.ylabel("Accuracy (%)")
68 |     # plt.title("Accuracy of dilation model, per window length")
69 |     plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_dilated_conv.pdf"))
70 | 
71 | if plot_linear_backward:
72 |     ## linear backward regression plots
73 |     # load the results
74 |     results_files_glob = os.path.join(base_results_folder, "results_linear_backward", "eval*.json")
75 |     results_files = glob.glob(results_files_glob)
76 |     all_results = []
77 | 
78 |     for result_file in results_files:
79 |         with open(result_file, "r") as f:
80 |             data = json.load(f)
81 |         for index, stim in enumerate(data['stim_filename']):
82 |             percentile = np.percentile(data['null_distr'][index], 95)
83 |             score = data['score'][index]
84 |             highpass = data['highpass'] if data['highpass'] is not None else 0.5
85 |             lowpass = data['lowpass'] if data['lowpass'] is not None else 32.0
86 | 
87 |             all_results.append([data['subject'], stim, score, highpass, lowpass, percentile, score > percentile, data['null_distr'][index]])
88 | 
89 |     df = pd.DataFrame(all_results, columns=['subject', 'stim', 'score', 'highpass', 'lowpass', 'percentile', 'significant', 'null_distr'])
90 | 
91 |     print('Confirming that we found neural tracking for each subject')
92 |     nb_subjects = len(pd.unique(df['subject']))
93 |     nb_significant = df.groupby('subject').agg({'significant': 'any'}).sum()
94 |     print('Found {} subjects, {} of which had at least one significant result'.format(nb_subjects, nb_significant))
95 | 
96 | 
97 |     subject_stories_group = df.groupby(['subject', 'stim']).agg({'significant': 'any'})
98 |     nb_recordings = len(subject_stories_group)
99 |     nb_significant_recordings = subject_stories_group.sum()
100 |     print("Found {} recordings, {} of which were significant".format(nb_recordings, nb_significant_recordings))
101 |     non_significant_recordings_series = df.groupby(['subject', 'stim']).agg({'significant': 'any'})['significant'] == False
102 |     non_significant_recordings = non_significant_recordings_series[non_significant_recordings_series].index
103 | 
104 |     # Table of non-significant results
105 |     print("Non-significant results:")
106 |     for subject, stimulus in non_significant_recordings:
107 |         print("{} & {} \\\\".format(subject, stimulus))
108 | 
109 |     # plot the results
110 |     ## General frequency plot
111 |     values_for_boxplot = []
112 |     significance_level = []
113 |     names_for_boxplot = []
114 |     for band_name, (highpass, lowpass) in freq_bands.items():
115 | 
116 |         selected_df = df[(df['highpass'] == highpass) & (df['lowpass'] == lowpass)]
117 |         values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
118 |         significance_level.append(np.percentile(selected_df['null_distr'].tolist(), 95))
119 | 
120 |     plt.figure(figsize=(8, 5))
121 |     # plt.boxplot(values_for_boxplot, labels=freq_bands.keys())
122 |     temp_df = pd.DataFrame(values_for_boxplot, index=freq_bands.keys(), columns=range(85)).T  # 85 = number of subjects
123 |     sns.violinplot(data=temp_df)
124 | 
125 |     xs = [np.linspace(x - 0.5, x + 0.5, 200) for x in range(5)]
126 |     ys = [[significance_level[i]]*200 for i in range(5)]
127 | 
128 |     plt.plot(np.reshape(xs, (-1,)), np.reshape(ys, (-1,)), '--', color='black')
129 |     plt.ylabel('Pearson correlation')
130 |     plt.xlabel('Frequency band')
131 |     # plt.grid(True)
132 |     plt.xlim(-0.5, len(values_for_boxplot)-0.5)
133 |     plt.title('Linear decoder performance across frequency bands')
134 |     # plt.show()
135 |     plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_frequency.pdf"))
136 |     # plt.close()
137 | 
138 |     ## plot the results per stimulus
139 |     values_for_boxplot = []
140 |     names_for_boxplot = []
141 |     def sort_key_fn(x):
142 |         split = x.split('_')
143 |         number = ord(split[0][0])*1e6 + int(split[1])*1e3
144 |         # artifact audiobook
145 |         if len(split) > 2 and split[2].isdigit():
146 |             number += int(split[2])
147 |         return number
148 | 
149 |     fig, (ax0, ax1) = plt.subplots(1, 2, width_ratios=[0.2, 0.8], figsize=(15, 6), sharey=True)
150 | 
151 |     for stimulus in sorted(pd.unique(df['stim']), key=sort_key_fn):
152 |         selected_df = df[(df['stim'] == stimulus) & (df['highpass'] == 0.5) & (df['lowpass'] == 4.0)]
153 |         values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
154 |         names_for_boxplot.append(stimulus + ' ({})'.format(len(selected_df)))
155 |     # plt.figure(figsize=(11,6))
156 |     ax1.boxplot(values_for_boxplot, labels=names_for_boxplot)
157 |     # ax1.set_ylabel('Pearson correlation')
158 |     # ax1.set_xlabel('Stimulus (Number of recordings)')
159 |     plt.xticks(rotation=90)
160 | 
161 |     # plt.grid(True)
162 |     ax1.set_title('Across stimuli')
163 |     # plt.tight_layout()
164 |     # plt.show()
165 |     # plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimuli.pdf"))
166 |     # plt.close()
167 | 
168 | 
169 |     ## plot the results per stimulus type
170 |     values_for_boxplot = []
171 |     names_for_boxplot = []
172 |     stim_type_selectors = {
173 |         'Audiobook': df['stim'].str.startswith('audiobook_'),
174 |         'Podcast': df['stim'].str.startswith('podcast_'),
175 |     }
176 |     test_data = []
177 |     for stimulus_type, stimulus_selector in stim_type_selectors.items():
178 |         selected_df = df[(stimulus_selector) & (df['highpass'] == 0.5) & (df['lowpass'] == 4.0) & (~(df['stim'].str.endswith('artefact'))) & (~(df['stim'].str.endswith('shifted'))) & (~(df['stim'].str.endswith('audiobook_1_1'))) & (~(df['stim'].str.endswith('audiobook_1_2')))]
179 |         values_for_boxplot.append(selected_df.groupby('subject').agg('mean')['score'].tolist())
180 |         names_for_boxplot.append(stimulus_type + ' ({})'.format(len(selected_df)))
181 |         test_data += [selected_df['score'].tolist()]
182 | 
183 |     print("MannWhitneyU test: {}, medians: {}, {}".format(scipy.stats.mannwhitneyu(test_data[0], test_data[1]), np.median(test_data[0]), np.median(test_data[1])))
184 |     # plt.figure()
185 |     ax0.boxplot(values_for_boxplot, labels=names_for_boxplot)
186 |     ax0.set_ylabel('Pearson correlation')
187 |     # ax0.set_xlabel('\nStimulus type (Number of recordings)')
188 |     # plt.xticks()
189 | 
190 |     # plt.grid(True)
191 |     ax0.set_title('Across stimuli types')
192 |     fig.suptitle('Linear decoder performance')
193 |     fig.supxlabel('Stimulus [type] (Number of recordings)')
194 | 
195 |     plt.tight_layout()
196 | 
197 |     # plt.show()
198 |     # plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimulus_type.pdf"))
199 |     plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_linear_backward_stimulus_combined.pdf"))
200 | 
201 |     # plt.close()
202 | 
203 | 
204 | if plot_linear_forward:
205 |     results_files_glob = os.path.join(base_results_folder, "results_linear_forward",
206 |                                       "eval*.json")
207 |     results_files = glob.glob(results_files_glob)
208 |     all_results = []
209 | 
210 |     for result_file in results_files:
211 |         with open(result_file, "r") as f:
212 |             data = json.load(f)
213 |         for index, stim in enumerate(data['stim_filename']):
214 |             percentile = np.percentile(data['null_distr'][index], 95)
215 |             score = data['score'][index]
216 |             null_distr = data['null_distr'][index]
217 |             highpass = data['highpass'] if data['highpass'] is not None else 0.5
218 |             lowpass = data['lowpass'] if data['lowpass'] is not None else 32.0
219 | 
220 |             all_results.append([data['subject'], stim, score, null_distr, highpass, lowpass, percentile, score > percentile])
221 | 
222 |     df = pd.DataFrame(all_results, columns=['subject', 'stim', 'scores_per_channel', 'null_distr', 'highpass', 'lowpass', 'percentile', 'significant'])
223 | 
224 |     montage = mne.channels.make_standard_montage('biosemi64')
225 |     sfreq = 64
226 |     info = mne.create_info(ch_names=montage.ch_names, sfreq=sfreq, ch_types='eeg').set_montage(montage)
227 | 
228 |     fig, axes = plt.subplots(3, 5, figsize=(16, 12))
229 |     temp_df = df.copy()
230 |     temp_df['filterband'] = temp_df['highpass'].astype(str) + '-' + temp_df['lowpass'].astype(str)
231 |     scores = temp_df.groupby(['filterband']).agg({'scores_per_channel': lambda x: list(x)})
232 |     all_scores = np.mean(scores['scores_per_channel'].tolist(), axis=1)
233 |     max_coef = np.max(all_scores)
234 |     min_coef = np.min(all_scores)
235 | 
236 |     stim_type_selectors = {
237 |         'All Stimuli': df['stim'].str.startswith(''),
238 |         'Audiobook': df['stim'].str.startswith('audiobook_'),
239 |         'Podcast': 
df['stim'].str.startswith('podcast_'), 240 | 241 | } 242 | for index, (stim_type, selector) in enumerate(stim_type_selectors.items()): 243 | 244 | axes[index][0].set_ylabel(stim_type) 245 | 246 | for index2, (band_name, (highpass, lowpass)) in enumerate(freq_bands.items()): 247 | ax = axes[index][index2] 248 | selected_df = df[(df['highpass'] == highpass) & (df['lowpass'] == lowpass) & (selector)] 249 | scores_per_subject = selected_df.groupby('subject').agg({'scores_per_channel': lambda x: np.mean(np.stack(x,axis=1), axis=1)}) 250 | mean_scores = np.mean(scores_per_subject.to_numpy(), axis=0).tolist()[0] 251 | # percentile = np.percentile(np.concatenate(selected_df['null_distr'].tolist(), axis=0), 95, axis=0) 252 | # plot the topoplot 253 | im , cn = mne.viz.plot_topomap(mean_scores, pos=info, axes=ax ,show=False, cmap='Reds', vlim=(min_coef,max_coef)) 254 | mne.viz.tight_layout() 255 | ax.set_title(f"{band_name} Hz") 256 | 257 | # cbar_ax = fig.add_axes([0.95, 0.15, 0.05, 0.7]) 258 | 259 | 260 | 261 | #plt.colorbar(im, cax=axes[5], label='Pearson correlation') #cax=cbar_ax, 262 | fig.suptitle('Forward model performance across stimuli types and frequency bands',y=0.94, fontsize=18) 263 | fig.tight_layout() 264 | 265 | fig.subplots_adjust(right=0.90) 266 | cbar_ax = fig.add_axes([0.92, 0.11, 0.02, 0.77]) 267 | fig.colorbar(im, cax=cbar_ax, label='Pearson correlation') 268 | plt.savefig(os.path.join(base_results_folder, 'figures', f"topoplot_linear_forward.pdf")) 269 | #plt.show() 270 | plt.close(fig) 271 | 272 | 273 | 274 | 275 | 276 | 277 | if plot_vlaai: 278 | 279 | ## vlaai results 280 | # load the results 281 | results_files = glob.glob(os.path.join(base_results_folder, "results_vlaai/eval*.json")) 282 | 283 | results = {} 284 | for f in results_files: 285 | 286 | with open(f, "rb") as ff: 287 | res = json.load(ff) 288 | #loop over res and get accuracy in a list 289 | acc = [] 290 | subs = [] 291 | for sub, sub_res in res.items(): 292 | if 'pearson_metric' in sub_res: 293 | acc.append(sub_res['pearson_metric']) 294 | subs.append(sub) 295 | 296 | results = acc 297 | 298 | # plot the results 299 | plt.figure() 300 | plt.boxplot(results, labels=['vlaai']) 301 | plt.xlabel("Model") 302 | plt.ylabel("Correlation") 303 | # plt.title("Correlation, per subject") 304 | plt.savefig(os.path.join(base_results_folder, 'figures', "boxplot_vlaai.pdf")) 305 | 306 | 307 | -------------------------------------------------------------------------------- /technical_validation/util/dataset_generator.py: -------------------------------------------------------------------------------- 1 | """Code for the dataset_generator for task1.""" 2 | import itertools 3 | import os 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.signal 8 | 9 | @tf.function 10 | def default_batch_equalizer_fn(*args): 11 | """Batch equalizer. 12 | Prepares the inputs for a model to be trained in 13 | match-mismatch task. It makes sure that match_env 14 | and mismatch_env are equally presented as a first 15 | envelope in match-mismatch task. 16 | 17 | Parameters 18 | ---------- 19 | args : Sequence[tf.Tensor] 20 | List of tensors representing feature data 21 | 22 | Returns 23 | ------- 24 | Tuple[Tuple[tf.Tensor], tf.Tensor] 25 | Tuple of the EEG/speech features serving as the input to the model and 26 | the labels for the match/mismatch task 27 | 28 | Notes 29 | ----- 30 | This function will also double the batch size. E.g. 
if the batch size of 31 | the elements in each of the args was 32, the output features will have 32 | a batch size of 64. 33 | """ 34 | eeg = args[0] 35 | new_eeg = tf.concat([eeg, eeg], axis=0) 36 | all_features = [new_eeg] 37 | for match, mismatch in zip(args[1::2], args[2::2]): 38 | stimulus_feature1 = tf.concat([match, mismatch], axis=0) 39 | stimulus_feature2 = tf.concat([mismatch, match], axis=0) 40 | all_features += [stimulus_feature1, stimulus_feature2] 41 | labels = tf.concat( 42 | [ 43 | tf.tile(tf.constant([[0]]), [tf.shape(eeg)[0], 1]), 44 | tf.tile(tf.constant([[1]]), [tf.shape(eeg)[0], 1]), 45 | ], 46 | axis=0, 47 | ) 48 | 49 | # print(new_eeg.shape, env1.shape, env2.shape, labels.shape) 50 | return tuple(all_features), labels 51 | 52 | 53 | def create_tf_dataset( 54 | data_generator, 55 | window_length, 56 | batch_equalizer_fn=None, 57 | hop_length=64, 58 | batch_size=64, 59 | data_types=(tf.float32, tf.float32, tf.float32), 60 | feature_dims=(64, 1, 1), 61 | eeg_first = True # if eeg_first is True, the first feature is eeg, otherwise the first feature is the stimulus 62 | ): 63 | """Creates a tf.data.Dataset. 64 | 65 | This will be used to create a dataset generator that will 66 | pass windowed data to a model in both tasks. 67 | 68 | Parameters 69 | --------- 70 | data_generator: DataGenerator 71 | A data generator. 72 | window_length: int 73 | Length of the decision window in samples. 74 | batch_equalizer_fn: Callable 75 | Function that will be applied on the data after batching (using 76 | the `map` method from tf.data.Dataset). In the match/mismatch task, 77 | this function creates the imposter segments and labels. 78 | hop_length: int 79 | Hop length between two consecutive decision windows. 80 | batch_size: Optional[int] 81 | If not None, specifies the batch size. In the match/mismatch task, 82 | this amount will be doubled by the default_batch_equalizer_fn 83 | data_types: Union[Sequence[tf.dtype], tf.dtype] 84 | The data types that the individual features of data_generator should 85 | be cast to. If you only specify a single datatype, it will be chosen 86 | for all EEG/speech features. 
87 | 88 | Returns 89 | ------- 90 | tf.data.Dataset 91 | A Dataset object that generates data to train/evaluate models 92 | efficiently 93 | """ 94 | # create tf dataset from generator 95 | dataset = tf.data.Dataset.from_generator( 96 | data_generator, 97 | output_signature=tuple( 98 | tf.TensorSpec(shape=(None, x), dtype=data_types[index]) 99 | for index, x in enumerate(feature_dims) 100 | ), 101 | ) 102 | # if forward model, swap the order of the features 103 | if not eeg_first: 104 | dataset = dataset.map(lambda eeg, speech: (speech, eeg)) 105 | 106 | # window dataset 107 | dataset = dataset.map( 108 | lambda *args: [ 109 | tf.signal.frame(arg, window_length, hop_length, axis=0) 110 | for arg in args 111 | ] 112 | ) 113 | 114 | # batch data 115 | dataset = dataset.interleave( 116 | lambda *args: tf.data.Dataset.from_tensor_slices(args), 117 | cycle_length=4, 118 | block_length=16, 119 | ) 120 | if batch_size is not None: 121 | dataset = dataset.batch(batch_size, drop_remainder=True) 122 | 123 | if batch_equalizer_fn is not None: 124 | # Create the labels and make sure classes are balanced 125 | dataset = dataset.map(batch_equalizer_fn) 126 | 127 | return dataset 128 | 129 | 130 | 131 | class MatchMismatchDataGenerator: 132 | """Generate data for the Match/Mismatch task.""" 133 | 134 | def __init__( 135 | self, 136 | files, 137 | window_length, 138 | spacing 139 | ): 140 | """Initialize the DataGenerator. 141 | 142 | Parameters 143 | ---------- 144 | files: Sequence[Union[str, pathlib.Path]] 145 | Files to load. 146 | window_length: int 147 | Length of the decision window. 148 | spacing: int 149 | Spacing between matched and mismatched samples 150 | """ 151 | self.window_length = window_length 152 | self.files = self.group_recordings(files) 153 | self.spacing = spacing 154 | 155 | def group_recordings(self, files): 156 | """Group recordings and corresponding stimuli. 157 | 158 | Parameters 159 | ---------- 160 | files : Sequence[Union[str, pathlib.Path]] 161 | List of filepaths to preprocessed and split EEG and speech features 162 | 163 | Returns 164 | ------- 165 | list 166 | Files grouped by the self.group_key_fn and subsequently sorted 167 | by the self.feature_sort_fn. 168 | """ 169 | new_files = [] 170 | grouped = itertools.groupby(sorted(files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3])) 171 | for recording_name, feature_paths in grouped: 172 | new_files += [sorted(feature_paths, key=lambda x: "0" if x == "eeg" else x)] 173 | return new_files 174 | 175 | def __len__(self): 176 | return len(self.files) 177 | 178 | def __getitem__(self, recording_index): 179 | """Get data for a certain recording. 180 | 181 | Parameters 182 | ---------- 183 | recording_index: int 184 | Index of the recording in this dataset 185 | 186 | Returns 187 | ------- 188 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]] 189 | The features corresponding to the recording_index recording 190 | """ 191 | data = [] 192 | for feature in self.files[recording_index]: 193 | data += [np.load(feature).astype(np.float32)] 194 | data = self.prepare_data(data) 195 | return tuple(tf.constant(x) for x in data) 196 | 197 | 198 | def __call__(self): 199 | """Load data for the next recording. 
200 | 201 | Yields 202 | ------- 203 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]] 204 | The features corresponding to the recording_index recording 205 | """ 206 | for idx in range(self.__len__()): 207 | yield self.__getitem__(idx) 208 | 209 | if idx == self.__len__() - 1: 210 | self.on_epoch_end() 211 | 212 | def on_epoch_end(self): 213 | """Change state at the end of an epoch.""" 214 | np.random.shuffle(self.files) 215 | 216 | def prepare_data(self, data): 217 | """Creates mismatch (imposter) envelope. 218 | 219 | Parameters 220 | ---------- 221 | data: Sequence[numpy.ndarray] 222 | Data to create an imposter for. 223 | 224 | Returns 225 | ------- 226 | tuple (numpy.ndarray, numpy.ndarray, numpy.ndarray, ...) 227 | (EEG, matched stimulus feature, mismatched stimulus feature, ...). 228 | """ 229 | eeg = data[0] 230 | new_length = eeg.shape[0] - self.window_length - self.spacing 231 | resulting_data = [eeg[:new_length, ...]] 232 | for stimulus_feature in data[1:]: 233 | match_feature = stimulus_feature[:new_length, ...] 234 | mismatch_feature = stimulus_feature[ 235 | self.spacing + self.window_length:, ... 236 | ] 237 | resulting_data += [match_feature, mismatch_feature] 238 | return resulting_data 239 | 240 | 241 | 242 | class RegressionDataGenerator: 243 | """Generate data for the regression task.""" 244 | 245 | def __init__( 246 | self, 247 | files, 248 | window_length= None, 249 | high_pass_freq= None, 250 | low_pass_freq= None, 251 | return_filenames=False, 252 | ): 253 | """Initialize the DataGenerator. 254 | 255 | Parameters 256 | ---------- 257 | files: Sequence[Union[str, pathlib.Path]] 258 | Files to load. 259 | window_length: int 260 | Length of the decision window. 261 | """ 262 | self.files = self.group_recordings(files) 263 | self.return_filenames = return_filenames 264 | self.high_pass_freq = high_pass_freq 265 | self.low_pass_freq = low_pass_freq 266 | 267 | if self.high_pass_freq and self.low_pass_freq: 268 | self.filter_ = scipy.signal.butter(N= 1, 269 | Wn =[self.high_pass_freq, self.low_pass_freq], 270 | btype= "bandpass", 271 | fs=64, 272 | output="sos") 273 | if self.high_pass_freq and not self.low_pass_freq: 274 | self.filter_ = scipy.signal.butter(N= 1, 275 | Wn = self.high_pass_freq, 276 | btype= "highpass", 277 | fs=64, 278 | output="sos") 279 | if not self.high_pass_freq and self.low_pass_freq: 280 | self.filter_ = scipy.signal.butter(N= 1, 281 | Wn = self.low_pass_freq, 282 | btype= "lowpass", 283 | fs=64, 284 | output="sos") 285 | if not self.high_pass_freq and not self.low_pass_freq: 286 | self.filter_ = None 287 | 288 | 289 | def group_recordings(self, files): 290 | """Group recordings and corresponding stimuli. 291 | 292 | Parameters 293 | ---------- 294 | files : Sequence[Union[str, pathlib.Path]] 295 | List of filepaths to preprocessed and split EEG and speech features 296 | 297 | Returns 298 | ------- 299 | list 300 | Files grouped by the self.group_key_fn and subsequently sorted 301 | by the self.feature_sort_fn. 302 | """ 303 | new_files = [] 304 | grouped = itertools.groupby(sorted(files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3])) 305 | for recording_name, feature_paths in grouped: 306 | new_files += [sorted(feature_paths, key=lambda x: "0" if x == "eeg" else x)] 307 | return new_files 308 | 309 | def __len__(self): 310 | return len(self.files) 311 | 312 | def __getitem__(self, recording_index): 313 | """Get data for a certain recording. 
314 | 315 | Parameters 316 | ---------- 317 | recording_index: int 318 | Index of the recording in this dataset 319 | 320 | Returns 321 | ------- 322 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]] 323 | The features corresponding to the recording_index recording 324 | """ 325 | data = [] 326 | for feature in self.files[recording_index]: 327 | data += [np.load(feature).astype(np.float32)] 328 | 329 | data = self.prepare_data(data) 330 | if self.return_filenames: 331 | return self.files[recording_index], tuple(tf.constant(x) for x in data) 332 | else: 333 | return tuple(tf.constant(x) for x in data) 334 | 335 | 336 | def __call__(self): 337 | """Load data for the next recording. 338 | 339 | Yields 340 | ------- 341 | Union[Tuple[tf.Tensor,...], Tuple[np.ndarray,...]] 342 | The features corresponding to the recording_index recording 343 | """ 344 | for idx in range(self.__len__()): 345 | yield self.__getitem__(idx) 346 | 347 | if idx == self.__len__() - 1: 348 | self.on_epoch_end() 349 | 350 | def on_epoch_end(self): 351 | """Change state at the end of an epoch.""" 352 | np.random.shuffle(self.files) 353 | 354 | def prepare_data(self, data): 355 | """ If specified, filter the data between highpass and lowpass 356 | :param data: list of numpy arrays, eeg and envelope 357 | :return: filtered data 358 | 359 | """ 360 | 361 | if self.filter_ is not None: 362 | resulting_data = [] 363 | # assuming time is the first dimension and channels the second 364 | resulting_data.append(scipy.signal.sosfiltfilt(self.filter_, data[0], axis=0)) 365 | 366 | for stimulus_feature in data[1:]: 367 | resulting_data.append(scipy.signal.sosfiltfilt(self.filter_, stimulus_feature, axis=0)) 368 | 369 | else: 370 | resulting_data = data 371 | 372 | return resulting_data 373 | 374 | 375 | -------------------------------------------------------------------------------- /preprocessing_code/sparrKULee.py: -------------------------------------------------------------------------------- 1 | """Run the default preprocessing pipeline on sparrKULee.""" 2 | import argparse 3 | import datetime 4 | import gzip 5 | import json 6 | import logging 7 | import os 8 | from typing import Any, Dict, Sequence 9 | 10 | import librosa 11 | import numpy as np 12 | import scipy.signal 13 | import scipy.signal.windows 14 | 15 | from brain_pipe.dataloaders.path import GlobLoader 16 | from brain_pipe.pipeline.default import DefaultPipeline 17 | from brain_pipe.preprocessing.brain.artifact import ( 18 | InterpolateArtifacts, 19 | ArtifactRemovalMWF, 20 | ) 21 | from brain_pipe.preprocessing.brain.eeg.biosemi import ( 22 | biosemi_trigger_processing_fn, 23 | ) 24 | from brain_pipe.preprocessing.brain.eeg.load import LoadEEGNumpy 25 | from brain_pipe.preprocessing.brain.epochs import SplitEpochs 26 | from brain_pipe.preprocessing.brain.link import ( 27 | LinkStimulusToBrainResponse, 28 | BIDSStimulusInfoExtractor, 29 | ) 30 | from brain_pipe.preprocessing.brain.rereference import CommonAverageRereference 31 | from brain_pipe.preprocessing.brain.trigger import ( 32 | AlignPeriodicBlockTriggers, 33 | ) 34 | from brain_pipe.preprocessing.filter import SosFiltFilt 35 | from brain_pipe.preprocessing.resample import ResamplePoly 36 | from brain_pipe.preprocessing.stimulus.audio.spectrogram import LibrosaMelSpectrogram 37 | from brain_pipe.preprocessing.stimulus.audio.envelope import GammatoneEnvelope 38 | from brain_pipe.preprocessing.stimulus.load import LoadStimuli 39 | from brain_pipe.runner.default import DefaultRunner 40 | from 
brain_pipe.save.default import DefaultSave 41 | from brain_pipe.utils.log import default_logging, DefaultFormatter 42 | from brain_pipe.utils.path import BIDSStimulusGrouper 43 | 44 | 45 | class BIDSAPRStimulusInfoExtractor(BIDSStimulusInfoExtractor): 46 | """Extract BIDS compliant stimulus information from an .apr file.""" 47 | 48 | def __call__(self, brain_dict: Dict[str, Any]): 49 | """Extract BIDS compliant stimulus information from an events.tsv file. 50 | 51 | Parameters 52 | ---------- 53 | brain_dict: Dict[str, Any] 54 | The data dict containing the brain data path. 55 | 56 | Returns 57 | ------- 58 | Sequence[Dict[str, Any]] 59 | The extracted event information. Each dict contains the information 60 | of one row in the events.tsv file 61 | """ 62 | event_info = super().__call__(brain_dict) 63 | # Find the apr file 64 | path = brain_dict[self.brain_path_key] 65 | apr_path = "_".join(path.split("_")[:-1]) + "_eeg.apr" 66 | # Read apr file 67 | apr_data = self.get_apr_data(apr_path) 68 | # Add apr data to event info 69 | for e_i in event_info: 70 | e_i.update(apr_data) 71 | return event_info 72 | 73 | def get_apr_data(self, apr_path: str): 74 | """Get the SNR from an .apr file. 75 | 76 | Parameters 77 | ---------- 78 | apr_path: str 79 | Path to the .apr file. 80 | 81 | Returns 82 | ------- 83 | Dict[str, Any] 84 | The SNR. 85 | """ 86 | import xml.etree.ElementTree as ET 87 | 88 | apr_data = {} 89 | tree = ET.parse(apr_path) 90 | root = tree.getroot() 91 | 92 | # Get SNR 93 | interactive_elements = root.findall(".//interactive/entry") 94 | for element in interactive_elements: 95 | description_element = element.find("description") 96 | if description_element.text == "SNR": 97 | apr_data["snr"] = element.find("new_value").text 98 | if "snr" not in apr_data: 99 | logging.warning(f"Could not find SNR in {apr_path}.") 100 | apr_data["snr"] = 100.0 101 | return apr_data 102 | 103 | 104 | def default_librosa_load_fn(path): 105 | """Load a stimulus using librosa. 106 | 107 | Parameters 108 | ---------- 109 | path: str 110 | Path to the audio file. 111 | 112 | Returns 113 | ------- 114 | Dict[str, Any] 115 | The data and the sampling rate. 116 | """ 117 | data, sr = librosa.load(path, sr=None) 118 | return {"data": data, "sr": sr} 119 | 120 | 121 | def default_npz_load_fn(path): 122 | """Load a stimulus from a .npz file. 123 | 124 | Parameters 125 | ---------- 126 | path: str 127 | Path to the .npz file. 128 | 129 | Returns 130 | ------- 131 | Dict[str, Any] 132 | The data and the sampling rate. 133 | """ 134 | np_data = np.load(path) 135 | return { 136 | "data": np_data["audio"], 137 | "sr": np_data["fs"], 138 | } 139 | 140 | 141 | DEFAULT_LOAD_FNS = { 142 | ".wav": default_librosa_load_fn, 143 | ".mp3": default_librosa_load_fn, 144 | ".npz": default_npz_load_fn, 145 | } 146 | 147 | 148 | def temp_stimulus_load_fn(path): 149 | """Load stimuli from (Gzipped) files. 150 | 151 | Parameters 152 | ---------- 153 | path: str 154 | Path to the stimulus file. 155 | 156 | Returns 157 | ------- 158 | Dict[str, Any] 159 | Dict containing the data under the key "data" and the sampling rate 160 | under the key "sr". 161 | """ 162 | if path.endswith(".gz"): 163 | with gzip.open(path, "rb") as f_in: 164 | data = dict(np.load(f_in)) 165 | return { 166 | "data": data["audio"], 167 | "sr": data["fs"], 168 | } 169 | 170 | extension = "." + ".".join(path.split(".")[1:]) 171 | if extension not in DEFAULT_LOAD_FNS: 172 | raise ValueError( 173 | f"Can't find a load function for extension {extension}. 
" 174 | f"Available extensions are {str(list(DEFAULT_LOAD_FNS.keys()))}." 175 | ) 176 | load_fn = DEFAULT_LOAD_FNS[extension] 177 | return load_fn(path) 178 | 179 | 180 | def bids_filename_fn(data_dict, feature_name, set_name=None): 181 | """Default function to generate a filename for the data. 182 | 183 | Parameters 184 | ---------- 185 | data_dict: Dict[str, Any] 186 | The data dict containing the data to save. 187 | feature_name: str 188 | The name of the feature. 189 | set_name: Optional[str] 190 | The name of the set. If no set name is given, the set name is not 191 | included in the filename. 192 | 193 | Returns 194 | ------- 195 | str 196 | The filename. 197 | """ 198 | 199 | filename = os.path.basename(data_dict["data_path"]).split("_eeg")[0] 200 | 201 | subject = filename.split("_")[0] 202 | session = filename.split("_")[1] 203 | filename += f"_desc-preproc-audio-{os.path.basename(data_dict.get('stimulus_path', '*.')).split('.')[0]}_{feature_name}" 204 | 205 | if set_name is not None: 206 | filename += f"_set-{set_name}" 207 | 208 | return os.path.join(subject, session, filename + ".npy") 209 | 210 | 211 | class SparrKULeeSpectrogramKwargs: 212 | """Default function to generate the kwargs for the librosa spectrogram.""" 213 | 214 | def __init__( 215 | self, 216 | stimulus_sr_key="stimulus_sr", 217 | target_fs=64, 218 | hop_length=None, 219 | win_length_sec=0.025, 220 | n_fft=None, 221 | window_fn=None, 222 | n_mels=28, 223 | fmin=-4.2735, 224 | fmax=5444, 225 | power=1.0, 226 | center=False, 227 | norm=None, 228 | htk=True, 229 | ): 230 | self.stimulus_sr_key = stimulus_sr_key 231 | self.target_fs = target_fs 232 | self.hop_length = hop_length 233 | self.win_length_sec = win_length_sec 234 | self.n_fft = n_fft 235 | self.window_fn = window_fn 236 | if window_fn is None: 237 | self.window_fn = scipy.signal.windows.hamming 238 | self.n_mels = n_mels 239 | self.fmin = fmin 240 | self.fmax = fmax 241 | self.power = power 242 | self.center = center 243 | self.norm = norm 244 | self.htk = htk 245 | 246 | def __call__(self, data_dict): 247 | """Default function to generate the kwargs for the librosa spectrogram. 248 | 249 | Parameters 250 | ---------- 251 | data_dict: Dict[str, Any] 252 | The data dict containing the data to save. 253 | 254 | Returns 255 | ------- 256 | Dict[str, Any] 257 | The kwargs for the librosa spectrogram. 
class SparrKULeeSpectrogramKwargs:
    """Generate the kwargs for the librosa mel spectrogram.

    The defaults are matched to the 64 Hz target rate of the preprocessed data.
    """

    def __init__(
        self,
        stimulus_sr_key="stimulus_sr",
        target_fs=64,
        hop_length=None,
        win_length_sec=0.025,
        n_fft=None,
        window_fn=None,
        n_mels=28,
        fmin=-4.2735,
        fmax=5444,
        power=1.0,
        center=False,
        norm=None,
        htk=True,
    ):
        self.stimulus_sr_key = stimulus_sr_key
        self.target_fs = target_fs
        self.hop_length = hop_length
        self.win_length_sec = win_length_sec
        self.n_fft = n_fft
        self.window_fn = window_fn
        if window_fn is None:
            self.window_fn = scipy.signal.windows.hamming
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        self.power = power
        self.center = center
        self.norm = norm
        self.htk = htk

    def __call__(self, data_dict):
        """Generate the kwargs for the librosa mel spectrogram.

        Parameters
        ----------
        data_dict: Dict[str, Any]
            The data dict containing the loaded stimulus.

        Returns
        -------
        Dict[str, Any]
            The kwargs for the librosa mel spectrogram.

        Notes
        -----
        Based on the code for the 2023 Auditory EEG Challenge:
        https://github.com/exporl/auditory-eeg-challenge-2023-code/blob/main/task1_match_mismatch/util/mel_spectrogram.py
        """
        fs = data_dict[self.stimulus_sr_key]
        result = {
            "fmin": self.fmin,
            "fmax": self.fmax,
            "n_mels": self.n_mels,
            "power": self.power,
            "center": self.center,
            "norm": self.norm,
            "htk": self.htk,
        }

        # By default, hop one target-rate sample (1 / target_fs seconds) per frame.
        result["hop_length"] = self.hop_length
        if self.hop_length is None:
            result["hop_length"] = int((1 / self.target_fs) * fs)

        # Window length in samples, derived from the window length in seconds.
        result["win_length"] = self.win_length_sec
        if self.win_length_sec is not None:
            result["win_length"] = int(self.win_length_sec * fs)

        # Default to the next power of two that fits the window.
        result["n_fft"] = self.n_fft
        if self.n_fft is None:
            result["n_fft"] = int(2 ** np.ceil(np.log2(result["win_length"])))

        result["window"] = self.window_fn(result["win_length"])
        return result
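
# Worked example (illustrative): for a stimulus sampled at fs = 48000 Hz with
# the defaults target_fs = 64 and win_length_sec = 0.025:
#
#     hop_length = int((1 / 64) * 48000)             # 750 samples
#     win_length = int(0.025 * 48000)                # 1200 samples
#     n_fft      = int(2 ** np.ceil(np.log2(1200)))  # 2048 samples
#
# One frame is produced every 1/64 s, so the resulting mel spectrogram is
# already at the 64 Hz rate to which the EEG is resampled below.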
def run_preprocessing_pipeline(
    root_dir,
    preprocessed_stimuli_dir,
    preprocessed_eeg_dir,
    nb_processes=-1,
    overwrite=False,
    log_path="sparrKULee.log",
):
    """Construct and run the preprocessing pipeline on sparrKULee.

    Parameters
    ----------
    root_dir: str
        The root directory of the dataset.
    preprocessed_stimuli_dir: str
        The directory where the preprocessed stimuli should be saved.
    preprocessed_eeg_dir: str
        The directory where the preprocessed EEG should be saved.
    nb_processes: int
        The number of processes to use. If -1, the number of processes is
        automatically determined.
    overwrite: bool
        Whether to overwrite existing files.
    log_path: str
        The path to the log file.
    """
    #########
    # PATHS #
    #########
    os.makedirs(preprocessed_eeg_dir, exist_ok=True)
    os.makedirs(preprocessed_stimuli_dir, exist_ok=True)

    ###########
    # LOGGING #
    ###########
    handler = logging.FileHandler(log_path)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(DefaultFormatter())
    default_logging(handlers=[handler])

    ################
    # DATA LOADING #
    ################
    logging.info("Retrieving BIDS layout...")
    data_loader = GlobLoader(
        [os.path.join(root_dir, "sub-*", "*", "eeg", "*.bdf*")],
        filter_fns=[lambda x: "restingState" not in x],
        key="data_path",
    )

    #########
    # STEPS #
    #########
    stimulus_steps = DefaultPipeline(
        steps=[
            LoadStimuli(load_fn=temp_stimulus_load_fn),
            GammatoneEnvelope(),
            LibrosaMelSpectrogram(
                power_factor=0.6, librosa_kwargs=SparrKULeeSpectrogramKwargs()
            ),
            ResamplePoly(64, "envelope_data", "stimulus_sr"),
            DefaultSave(
                preprocessed_stimuli_dir,
                to_save={
                    "envelope": "envelope_data",
                    # Comment out the next line if you don't want to save the
                    # mel spectrogram.
                    "mel": "spectrogram_data",
                },
                overwrite=overwrite,
            ),
            DefaultSave(preprocessed_stimuli_dir, overwrite=overwrite),
        ],
        on_error=DefaultPipeline.RAISE,
    )

    eeg_steps = [
        LinkStimulusToBrainResponse(
            stimulus_data=stimulus_steps,
            extract_stimuli_information_fn=BIDSAPRStimulusInfoExtractor(),
            grouper=BIDSStimulusGrouper(
                bids_root=root_dir,
                mapping={"stim_file": "stimulus_path", "trigger_file": "trigger_path"},
                subfolders=["stimuli", "eeg"],
            ),
        ),
        LoadEEGNumpy(unit_multiplier=1e6, channels_to_select=list(range(64))),
        SosFiltFilt(
            scipy.signal.butter(1, 0.5, "highpass", fs=1024, output="sos"),
            emulate_matlab=True,
            axis=1,
        ),
        InterpolateArtifacts(),
        AlignPeriodicBlockTriggers(biosemi_trigger_processing_fn),
        SplitEpochs(),
        ArtifactRemovalMWF(),
        CommonAverageRereference(),
        ResamplePoly(64, axis=1),
        DefaultSave(
            preprocessed_eeg_dir,
            {"eeg": "data"},
            overwrite=overwrite,
            clear_output=True,
            filename_fn=bids_filename_fn,
        ),
    ]

    ########################
    # RUNNING THE PIPELINE #
    ########################
    logging.info("Starting with the EEG preprocessing")
    logging.info("===================================")

    # Create the EEG pipeline and run it on the data_dicts from the loader.
    eeg_pipeline = DefaultPipeline(steps=eeg_steps)

    DefaultRunner(
        nb_processes=nb_processes,
        logging_config=lambda: None,
    ).run(
        [(data_loader, eeg_pipeline)],
    )
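
# Example invocation from the command line (illustrative paths; the flags are
# defined in the __main__ block below):
#
#     python preprocessing_code/sparrKULee.py \
#         --dataset_folder /path/to/sparrKULee \
#         --nb_processes 4 \
#         --overwrite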
if __name__ == "__main__":
    # Code for the sparrKULee dataset
    # (https://rdr.kuleuven.be/dataset.xhtml?persistentId=doi:10.48804/K3VSND)
    #
    # A slight adaptation of this code can also be found in the sparrKULee
    # repository (https://github.com/exporl/auditory-eeg-dataset) under
    # preprocessing_code/sparrKULee.py

    # Load the config
    with open("config.json", "r") as f:
        config = json.load(f)

    # Set the correct paths as default arguments
    dataset_folder = config["dataset_folder"]
    derivatives_folder = os.path.join(dataset_folder, config["derivatives_folder"])
    preprocessed_stimuli_folder = os.path.join(
        derivatives_folder, config["preprocessed_stimuli_folder"]
    )
    preprocessed_eeg_folder = os.path.join(
        derivatives_folder, config["preprocessed_eeg_folder"]
    )
    # Set the default log folder
    default_log_folder = os.path.dirname(os.path.abspath(__file__))

    # Parse arguments from the command line
    parser = argparse.ArgumentParser(description="Preprocess the auditory EEG dataset")
    parser.add_argument(
        "--nb_processes",
        type=int,
        default=-1,
        help="Number of processes to use for the preprocessing. "
        "The default is to use all available cores (-1).",
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing files"
    )
    parser.add_argument(
        "--log_path",
        type=str,
        default=os.path.join(default_log_folder, "sparrKULee_{datetime}.log"),
    )
    parser.add_argument(
        "--dataset_folder",
        type=str,
        default=dataset_folder,
        help="Path to the folder where the dataset is downloaded",
    )
    parser.add_argument(
        "--preprocessed_stimuli_path",
        type=str,
        default=preprocessed_stimuli_folder,
        help="Path to the folder where the preprocessed stimuli will be saved",
    )
    parser.add_argument(
        "--preprocessed_eeg_path",
        type=str,
        default=preprocessed_eeg_folder,
        help="Path to the folder where the preprocessed EEG will be saved",
    )
    args = parser.parse_args()

    # Run the preprocessing pipeline
    run_preprocessing_pipeline(
        args.dataset_folder,
        args.preprocessed_stimuli_path,
        args.preprocessed_eeg_path,
        args.nb_processes,
        args.overwrite,
        args.log_path.format(
            datetime=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        ),
    )
--------------------------------------------------------------------------------