├── .gitignore ├── LICENSE ├── README.md ├── moda ├── 1_checking_moda_content.py ├── 2_generate_moda_npz_files.py ├── 3_generate_moda_segments.py └── __init__.py ├── nsrr ├── __init__.py ├── check_dataset_object.py ├── check_npz_stats.py ├── check_npz_std.py ├── check_raw_data.py ├── check_raw_metadata.py ├── explore_metadata.ipynb ├── nsrr_utils.py ├── prepare_data.py └── prepare_metadata.py ├── requirements.txt ├── resources ├── .keep └── invalid_nsrr_subjects.txt ├── scripts ├── 1_train.py ├── 2_crossval_performance.py └── nsrr_inference.py └── sleeprnn ├── __init__.py ├── common ├── __init__.py ├── checks.py ├── constants.py ├── optimal_thresholds.py ├── pkeys.py └── viz.py ├── data ├── __init__.py ├── cap_ss.py ├── dataset.py ├── inta_ss.py ├── mass_kc.py ├── mass_raw.py ├── mass_ss.py ├── moda_ss.py ├── nsrr_ss.py ├── pink.py ├── stamp_correction.py └── utils.py ├── detection ├── __init__.py ├── det_utils.py ├── ensemble.py ├── feeder_dataset.py ├── metrics.py ├── postprocessing.py ├── postprocessor.py ├── predicted_dataset.py ├── simple_detection.py └── threshold_optimization.py ├── helpers ├── __init__.py ├── misc.py ├── performer.py ├── plotter.py ├── printer.py ├── reader.py └── sharing.py └── nn ├── __init__.py ├── adam_w.py ├── augmentations.py ├── base_model.py ├── expert_feats.py ├── feeding.py ├── layers.py ├── losses.py ├── metrics.py ├── models.py ├── networks.py ├── networks_v2.py ├── networks_v3.py ├── optimizers.py ├── spectrum.py └── wave_augment.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | bin 3 | tmp 4 | .ipynb_checkpoints 5 | __pychache__ 6 | .idea 7 | .loadpath 8 | .recommenders 9 | .settings 10 | datasets 11 | results 12 | experiments 13 | 14 | 15 | # Extensions 16 | *.DS_Store 17 | *.pyc 18 | *.bak 19 | *.swp 20 | *.mat 21 | *.edf 22 | *.rec 23 | 24 | 25 | # External tool builders 26 | .externalToolBuilders/ 27 | 28 | # Locally stored "Eclipse launch configurations" 29 
| *.launch 30 | 31 | # PyDev specific (Python IDE for Eclipse) 32 | *.pydevproject 33 | 34 | # CDT-specific (C/C++ Development Tooling) 35 | .cproject 36 | 37 | # Java annotation processor (APT) 38 | .factorypath 39 | 40 | # PDT-specific (PHP Development Tools) 41 | .buildpath 42 | 43 | # sbteclipse plugin 44 | .target 45 | 46 | # Tern plugin 47 | .tern-project 48 | 49 | # TeXlipse plugin 50 | .texlipse 51 | 52 | # STS (Spring Tool Suite) 53 | .springBeans 54 | 55 | # Code Recommenders 56 | .recommenders/ 57 | 58 | # Scala IDE specific (Scala & Java development for Eclipse) 59 | .cache-main 60 | .scala_dependencies 61 | .worksheet 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nicolás I. Tapia Rivas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sleep EEG Event Detector (SEED) 2 | 3 | Repository for the code and pretrained weights of our deep-learning based detector (SEED) described in: 4 | 5 | Tapia-Rivas, N.I., Estévez, P.A. & Cortes-Briones, J.A. A robust deep learning detector for sleep spindles and K-complexes: towards population norms. _Sci Rep_ **14**, 263 (2024). 6 | https://doi.org/10.1038/s41598-023-50736-7 7 | 8 | If you find this software useful, please consider citing our work. 9 | 10 | ## Setup 11 | 12 | SEED is implemented using TensorFlow 1 in python 3.9. 13 | 14 | For a safe installation, create a virtual environment with `python` 3.9. For example, if you use `conda`: 15 | ```bash 16 | conda create -n seed python=3.9 17 | conda activate seed 18 | ``` 19 | 20 | Inside the environment, install dependencies running `pip install -r requirements.txt` 21 | 22 | 23 | 24 | ## Getting started 25 | 26 | In the current state of the code, your simplest entrypoint is `/scripts/`. 27 | - `1_train.py`: Trains SEED, and generates predictions of the final model. 28 | - `2_crossval_performance.py`: For a given training run, it fits the detection threshold of SEED and reports the cross-validation performance of that optimal threshold. 29 | - `nsrr_inference.py`: Script that uses an ensemble of trained SEED models to predict sleep spindles on the NSRR dataset. 30 | 31 | Inside the scripts you will find further documentation. 32 | 33 | ## How to load data 34 | 35 | The code loads a dataset with the function `load_dataset` defined in `sleeprnn/helpers/reader.py`. For example, to load the MODA dataset, the function loads the class `ModaSS` defined in `sleeprnn/data/moda_ss.py`. All of these datasets are sub-classes of the `Dataset` base class that is defined in `sleeprnn/data/dataset.py`. 
36 | 37 | ### Use a dataset used by our research 38 | 39 | If you want to use MASS data ([MASS paper](https://pubmed.ncbi.nlm.nih.gov/24909981/), [MODA paper](https://www.nature.com/articles/s41597-020-0533-4)), the easiest way is to organize your data files so that one of the following classes (defined in `sleeprnn/data/`) can be instantiated, according to the expected directory tree documented in each class definition: 40 | - `MassSS`: Class to instantiate the MASS-SS2 dataset considering their sleep spindle annotations, from both experts. 41 | - `MassKC`: Class to instantiate the MASS-SS2 dataset considering their K-complexes annotations. 42 | - `ModaSS`: Class to instantiate the MODA dataset with its sleep spindle annotations. The MODA dataset is composed of signal segments extracted from the full MASS dataset (that is, from its five subsets, not only MASS-SS2), that were annotated for sleep spindles by a consensus of experts. To load this class, first run the scripts located in the `moda/` directory. Inside each script you will find further information. 43 | 44 | ### Use a dataset of your own 45 | 46 | On the other hand, if you want to use your own dataset, the easiest way in the current state of the code is to create your own subclass of the `Dataset` base class in the `sleeprnn/data/` package. 47 | 48 | You can use the existing subclasses as examples for the implementation. The base class has documentation for the expected arguments in its constructor. Besides giving these arguments, you must implement the `_load_from_source` method, which is in charge of reading raw files and returning the data dictionary, as illustrated by the template implementation in `Dataset`, or as you can also see in the actual implementations for MASS-SS2 and MODA. Note that the implemented function not only has to read data and structure it in the expected format, but you also need to preprocess the data (most likely just bandpass filtering the EEG). 
Again, see the current implementations for MASS-SS2 as an example. 49 | 50 | Once you create your own subclass of `Dataset`, you can add it as another option in the function `load_dataset` defined in `sleeprnn/helpers/reader.py`, so that it can be easily loaded by the training and fitting scripts by name. 51 | 52 | As a future improvement, I would like to offer a more flexible option. For now, this is the way I implemented datasets to handle various steps that I needed for my experiments. 53 | 54 | 55 | ## Pending tasks 56 | 57 | With the goal of sharing working code as fast as possible, I decided to share my original research code directly as a first step. 58 | 59 | Therefore, it's quite messy. It contains many outdated packages (whose versions are specified in `requirements.txt`) and many pieces of code that are not used by the final published model. My plan is to clean and refactor the codebase so that only the published model is present and users can easily either train it on their own data or use checkpoints to predict on it. 60 | 61 | These are some known pending tasks for the future: 62 | 63 | - [ ] Upload and share existing model checkpoints that were used for the paper (TensorFlow 1). 64 | - [ ] Improve code: clean unused model variants, migrate to TensorFlow 2, improve documentation, simplify process to use custom data, add example notebooks. 65 | - [ ] Generate and share new checkpoints following the improved code. 66 | -------------------------------------------------------------------------------- /moda/1_checking_moda_content.py: -------------------------------------------------------------------------------- 1 | """Some EDA of the signals that compose the MODA database. 2 | 3 | Ref: 4 | Lacourse, K., Yetton, B., Mednick, S. et al. 5 | Massive online data annotation, crowdsourcing to generate high quality sleep spindle annotations from EEG data. 6 | Sci Data 7, 190 (2020). 
7 | https://doi.org/10.1038/s41597-020-0533-4 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | import sys 16 | from pprint import pprint 17 | 18 | 19 | import numpy as np 20 | import pyedflib 21 | 22 | project_root = os.path.abspath("..") 23 | sys.path.append(project_root) 24 | 25 | # Change this to the path where the MASS dataset is stored: 26 | PATH_MODA_RAW = "/home/ntapia/Projects/Sleep_Databases/MASS_Database_2020_Full/C1" 27 | 28 | 29 | def get_filepaths(main_path): 30 | files = os.listdir(main_path) 31 | files = [f for f in files if ".edf" in f] 32 | signal_files = [f for f in files if "PSG" in f] 33 | states_files = [f for f in files if "Base" in f] 34 | signal_files = [os.path.join(main_path, f) for f in signal_files] 35 | states_files = [os.path.join(main_path, f) for f in states_files] 36 | signal_files.sort() 37 | states_files.sort() 38 | return signal_files, states_files 39 | 40 | 41 | if __name__ == "__main__": 42 | # Using sleep state information is not necessary for the MODA dataset, because 43 | # sleep spindle annotations are only made for valid sleep states 44 | 45 | required_channel = "C3" 46 | 47 | signal_files, states_files = get_filepaths(PATH_MODA_RAW) 48 | assert len(signal_files) == len(states_files) 49 | print("%d subjects" % len(signal_files)) 50 | 51 | all_channel_names = [] 52 | all_fs = [] 53 | 54 | n_max = 10 55 | for signal_f, states_f in zip(signal_files[:n_max], states_files[:n_max]): 56 | subject_id_1 = signal_f.split("/")[-1].split(" ")[0] 57 | subject_id_2 = states_f.split("/")[-1].split(" ")[0] 58 | assert subject_id_1 == subject_id_2 59 | subject_id = subject_id_1 60 | print("ID %s" % subject_id) 61 | # Read signal 62 | with pyedflib.EdfReader(signal_f) as file: 63 | channel_names = file.getSignalLabels() 64 | channel_names_valid = [ 65 | chn for chn in channel_names if required_channel in chn 66 | ] 67 | fs_valid = 
[] 68 | for chn in channel_names_valid: 69 | extraction_loc = channel_names.index(chn) 70 | fs_original = file.samplefrequency(extraction_loc) 71 | chn_check = file.getLabel(extraction_loc) 72 | assert chn_check == chn 73 | fs_valid.append(fs_original) 74 | print(" Channels", channel_names_valid) 75 | print(" Freqs ", fs_valid) 76 | all_channel_names.append(channel_names_valid) 77 | all_fs.append(fs_valid) 78 | 79 | for l in [all_channel_names, all_fs]: 80 | l = np.concatenate(l).flatten() 81 | values, counts = np.unique(l, return_counts=True) 82 | for v, c in zip(values, counts): 83 | print("Value %s with count %s" % (v, c)) 84 | -------------------------------------------------------------------------------- /moda/2_generate_moda_npz_files.py: -------------------------------------------------------------------------------- 1 | """Extracts only the relevant data from the full MODA database to save storage. 2 | 3 | It saves the extracted signal for each subject in independent npz files under 4 | the resources/datasets/moda/signals_npz/ directory. 5 | These files are the input for the 3_generate_moda_segments.py script. 
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import sys 14 | from pprint import pprint 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import pyedflib 19 | 20 | project_root = os.path.abspath("..") 21 | sys.path.append(project_root) 22 | 23 | # Change this to the path where the MASS dataset is stored: 24 | PATH_MODA_RAW = "/home/ntapia/Projects/Sleep_Databases/MASS_Database_2020_Full/C1" 25 | # Change this to the path where the 8_MODA_primChan_180sjt.txt file is stored: 26 | PATH_SUBJECT_CHANNEL_INFO = "../resources/datasets/moda/8_MODA_primChan_180sjt.txt" 27 | 28 | 29 | def get_signal(file, chn_name): 30 | channel_names = file.getSignalLabels() 31 | channel_loc = channel_names.index(chn_name) 32 | check = file.getLabel(channel_loc) 33 | assert check == chn_name 34 | fs = file.samplefrequency(channel_loc) 35 | signal = file.readSignal(channel_loc) 36 | return signal, fs 37 | 38 | 39 | if __name__ == "__main__": 40 | save_dir = "../resources/datasets/moda/signals_npz" 41 | save_dir = os.path.abspath(save_dir) 42 | os.makedirs(save_dir, exist_ok=True) 43 | print("Files will be saved at %s" % save_dir) 44 | 45 | info = pd.read_csv(PATH_SUBJECT_CHANNEL_INFO, delimiter="\t") 46 | subject_ids = info.subject.values 47 | subject_ids = [s.split(".")[0] for s in subject_ids] 48 | channels_for_moda = info.channel.values 49 | 50 | n_subjects = len(subject_ids) 51 | print("%d subjects" % n_subjects) 52 | for i in range(n_subjects): 53 | subject_id = subject_ids[i] 54 | channel_for_moda = channels_for_moda[i] 55 | signal_f = os.path.join(PATH_MODA_RAW, "%s PSG.edf" % subject_id) 56 | print("Loading %s from %s" % (channel_for_moda, signal_f)) 57 | with pyedflib.EdfReader(signal_f) as file: 58 | channel_names = file.getSignalLabels() 59 | if channel_for_moda == "C3-A2": 60 | # re-reference 61 | required_channel = [chn for chn in channel_names if "C3" in 
chn][0] 62 | reference_channel = [chn for chn in channel_names if "A2" in chn][0] 63 | required_signal, required_fs = get_signal(file, required_channel) 64 | reference_signal, reference_fs = get_signal(file, reference_channel) 65 | assert required_fs == reference_fs 66 | signal = required_signal - reference_signal 67 | channel_extracted = "(%s)-(%s)" % (required_channel, reference_channel) 68 | else: 69 | # use channel as-is 70 | required_channel = [chn for chn in channel_names if "C3" in chn][0] 71 | required_signal, required_fs = get_signal(file, required_channel) 72 | signal = required_signal 73 | channel_extracted = required_channel 74 | signal = signal.astype(np.float32) 75 | fs = int(required_fs) 76 | print( 77 | "(%03d/%03d) Subject %s, sampling %s Hz, channel %s" 78 | % (i + 1, n_subjects, subject_id, fs, channel_extracted), 79 | flush=True, 80 | ) 81 | data_dict = { 82 | "dataset_id": "MASS-C1", 83 | "subject_id": subject_id, 84 | "sampling_rate": fs, 85 | "channel": channel_extracted, 86 | "signal": signal, 87 | } 88 | fname = os.path.join(save_dir, "moda_%s.npz" % subject_id) 89 | np.savez(fname, **data_dict) 90 | -------------------------------------------------------------------------------- /moda/3_generate_moda_segments.py: -------------------------------------------------------------------------------- 1 | """Generates moda_preprocessed_segments.npz (~160 MB) and metadata.csv (~50 kB). 2 | 3 | The file is a pre-processed (cropped and filtered) MODA dataset that includes only the 4 | annotated portions. 5 | Before running this script, run 2_generate_moda_npz_files.py because it uses the files 6 | generated by that script. 7 | 8 | The generated file is the information source actually used to create a python 9 | dataset of MODA to feed the models. 
10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import os 17 | import sys 18 | from pprint import pprint 19 | 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import pandas as pd 23 | from scipy.signal import butter, sosfiltfilt 24 | 25 | project_root = os.path.abspath("..") 26 | sys.path.append(project_root) 27 | 28 | from sleeprnn.data import utils 29 | 30 | # In this path we expect to find MODA metadata files and the npz files generated by 31 | # generate_moda_npz_files.py script: 32 | MODA_PATH = "../resources/datasets/moda" 33 | 34 | 35 | def get_subjects(): 36 | p1_info = pd.read_csv( 37 | os.path.join(MODA_PATH, "6_segListSrcDataLoc_p1.txt"), delimiter="\t" 38 | ) 39 | p1_subjects = np.unique(p1_info.subjectID.values) 40 | p2_info = pd.read_csv( 41 | os.path.join(MODA_PATH, "7_segListSrcDataLoc_p2.txt"), delimiter="\t" 42 | ) 43 | p2_subjects = np.unique(p2_info.subjectID.values) 44 | phase_dict = {} 45 | for subject_id in p1_subjects: 46 | phase_dict[subject_id] = 1 47 | for subject_id in p2_subjects: 48 | phase_dict[subject_id] = 2 49 | subject_ids = list(phase_dict.keys()) 50 | subject_ids.sort() 51 | return subject_ids, phase_dict 52 | 53 | 54 | def get_data(subject_id): 55 | data = np.load(os.path.join(MODA_PATH, "signals_npz/moda_%s.npz" % subject_id)) 56 | fs = data["sampling_rate"].item() 57 | channel = data["channel"].item() 58 | signal = data["signal"] 59 | return signal, fs, channel 60 | 61 | 62 | def get_annotations(subject_id): 63 | annot = pd.read_csv( 64 | os.path.join(MODA_PATH, "MODA_annotFiles/%s_MODA_GS.txt" % subject_id), 65 | delimiter="\t", 66 | ) 67 | segments_info = annot[annot.eventName == "segmentViewed"] 68 | segments_start = np.sort(segments_info.startSec.values) 69 | spindles_info = annot[annot.eventName == "spindle"] 70 | spindles_start = spindles_info.startSec.values 71 | spindles_end = 
spindles_info.durationSec.values + spindles_start 72 | return segments_start, spindles_start, spindles_end 73 | 74 | 75 | def get_segment( 76 | x, 77 | fs, 78 | start_time, 79 | segment_duration=115, 80 | border_duration=30, 81 | border_to_filter_duration=10, 82 | ): 83 | # Extract a single segment of EEG 84 | total_border = border_duration * fs + border_to_filter_duration * fs 85 | segment_size = 2 * total_border + segment_duration * fs 86 | start_sample = int(start_time * fs - total_border) 87 | end_sample = int(start_sample + segment_size) 88 | segment = x[start_sample:end_sample].copy() 89 | return segment 90 | 91 | 92 | def filter_segment( 93 | x, fs, lowcut=0.3, highcut=30, filter_order=10, border_to_filter_duration=10 94 | ): 95 | sos = butter(filter_order, lowcut, btype="highpass", fs=fs, output="sos") 96 | x = sosfiltfilt(sos, x) 97 | sos = butter(filter_order, highcut, btype="lowpass", fs=fs, output="sos") 98 | x = sosfiltfilt(sos, x) 99 | border_size = int(fs * border_to_filter_duration) 100 | return x[border_size:-border_size] 101 | 102 | 103 | def get_label(binaries, fs, start_time, segment_duration=115, border_duration=30): 104 | border_size = int(fs * border_duration) 105 | segment_size = 2 * border_size + segment_duration * fs 106 | start_sample = int(start_time * fs - border_size) 107 | end_sample = int(start_sample + segment_size) 108 | labels = binaries[start_sample:end_sample].copy() 109 | labels[:border_size] = -1 110 | labels[-border_size:] = -1 111 | 112 | return labels 113 | 114 | 115 | if __name__ == "__main__": 116 | save_dir = "../resources/datasets/moda/segments" 117 | save_dir = os.path.abspath(save_dir) 118 | os.makedirs(save_dir, exist_ok=True) 119 | print("Files will be saved at %s" % save_dir) 120 | 121 | metadata_l = [] 122 | segments_signal_l = [] 123 | segments_labels_l = [] 124 | segments_subjects_l = [] 125 | segments_phases_l = [] 126 | 127 | subject_ids, phase_dict = get_subjects() 128 | for subject_id in subject_ids: 129 | 
print(subject_id, flush=True) 130 | signal, fs, channel = get_data(subject_id) 131 | segments_start, spindles_start, spindles_end = get_annotations(subject_id) 132 | stamps_time = np.stack([spindles_start, spindles_end], axis=1) 133 | stamps = (stamps_time * fs).astype(np.int32) 134 | binary_labels = utils.stamp2seq(stamps, 0, signal.size - 1) 135 | for single_start in segments_start: 136 | segment_signal_prefilter = get_segment(signal, fs, single_start) 137 | segment_signal = filter_segment(segment_signal_prefilter, fs).astype( 138 | np.float32 139 | ) 140 | segment_label = get_label(binary_labels, fs, single_start).astype(np.int8) 141 | # Append data 142 | segments_signal_l.append(segment_signal) 143 | segments_labels_l.append(segment_label) 144 | segments_subjects_l.append(subject_id) 145 | segments_phases_l.append(phase_dict[subject_id]) 146 | metadata_l.append( 147 | { 148 | "subject_id": subject_id, 149 | "phase": phase_dict[subject_id], 150 | "channel": channel, 151 | "fs": fs, 152 | "start_seconds": single_start, 153 | "segment_seconds": 115, 154 | "border_seconds": 30, 155 | } 156 | ) 157 | 158 | # Format data 159 | segments_signal = np.stack(segments_signal_l, axis=0) 160 | segments_labels = np.stack(segments_labels_l, axis=0) 161 | segments_subjects = np.stack(segments_subjects_l, axis=0) 162 | segments_phases = np.stack(segments_phases_l, axis=0).astype(np.int8) 163 | metadata_table = pd.DataFrame(metadata_l) 164 | 165 | # Save data 166 | np.savez( 167 | os.path.join(save_dir, "moda_preprocessed_segments.npz"), 168 | signals=segments_signal, 169 | labels=segments_labels, 170 | subjects=segments_subjects, 171 | phases=segments_phases, 172 | ) 173 | metadata_table.to_csv(os.path.join(save_dir, "metadata.csv"), sep="\t") 174 | -------------------------------------------------------------------------------- /moda/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nicolasigor/Sleep-EEG-Event-Detector/24322278e3f3ef7535413a65a61fffc3ce4f4e01/moda/__init__.py -------------------------------------------------------------------------------- /nsrr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolasigor/Sleep-EEG-Event-Detector/24322278e3f3ef7535413a65a61fffc3ce4f4e01/nsrr/__init__.py -------------------------------------------------------------------------------- /nsrr/check_dataset_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pprint import pprint 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | project_root = ".." 9 | sys.path.append(project_root) 10 | 11 | from sleeprnn.data.nsrr_ss import NsrrSS 12 | 13 | if __name__ == "__main__": 14 | NsrrSS() 15 | -------------------------------------------------------------------------------- /nsrr/check_npz_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | project_root = ".." 
7 | sys.path.append(project_root) 8 | 9 | DATASETS_PATH = os.path.join(project_root, "resources", "datasets", "nsrr") 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | dataset_name_list = ["shhs1", "mros1", "chat1", "sof", "cfs", "ccshs"] 15 | 16 | # Keys in dataset: 17 | # 'dataset' 18 | # 'subject_id' 19 | # 'channel' 20 | # 'signal' 21 | # 'sampling_rate' 22 | # 'hypnogram' 23 | # 'epoch_duration' 24 | # 'bandpass_filter' 25 | # 'resampling_function' 26 | # 'original_sampling_rate' 27 | 28 | n2_id = "Stage 2 sleep|2" 29 | 30 | for dataset_name in dataset_name_list: 31 | print("Check %s" % dataset_name) 32 | 33 | npz_dir = os.path.abspath( 34 | os.path.join(DATASETS_PATH, dataset_name, "register_and_state") 35 | ) 36 | all_files = os.listdir(npz_dir) 37 | all_files.sort() 38 | all_files = [f for f in all_files if ".npz" in f] 39 | all_fname = np.array(all_files) 40 | all_files = [os.path.join(npz_dir, f) for f in all_files] 41 | all_files = np.array(all_files) 42 | 43 | all_original_fs = [] 44 | all_channel = [] 45 | duration_in_seconds = 0 46 | stage_labels = [] 47 | for f in all_files: 48 | data_dict = np.load(f) 49 | 50 | # check original fs 51 | original_fs = data_dict["original_sampling_rate"] 52 | all_original_fs.append(original_fs) 53 | 54 | # check channel extracted 55 | channel = data_dict["channel"] 56 | all_channel.append(channel) 57 | 58 | # check N2 duration 59 | hypnogram = data_dict["hypnogram"] 60 | stage_labels.append(np.unique(hypnogram)) 61 | epoch_duration = data_dict["epoch_duration"] 62 | n2_pages = (hypnogram == n2_id).sum() 63 | n2_duration = epoch_duration * n2_pages 64 | duration_in_seconds += n2_duration 65 | stage_labels = np.unique(np.concatenate(stage_labels)) 66 | 67 | print("\nReport:") 68 | print("Subjects %d" % len(all_files)) 69 | print( 70 | "Duration: %1.4f s, %1.4f h" 71 | % (duration_in_seconds, duration_in_seconds / 3600) 72 | ) 73 | 74 | print("\nOriginal fs found:") 75 | values, counts = np.unique(all_original_fs, 
return_counts=True) 76 | for v, c in zip(values, counts): 77 | print("%s: %d" % (v, c)) 78 | 79 | print("\nChannels found:") 80 | values, counts = np.unique(all_channel, return_counts=True) 81 | for v, c in zip(values, counts): 82 | print("%s: %d" % (v, c)) 83 | 84 | print("\nSleep stage labels found:", stage_labels) 85 | -------------------------------------------------------------------------------- /nsrr/check_npz_std.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | project_root = ".." 7 | sys.path.append(project_root) 8 | 9 | DATASETS_PATH = os.path.join(project_root, "resources", "datasets", "nsrr") 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | dataset_name_list = [ 15 | "shhs1", 16 | ] 17 | 18 | top_k = 40 19 | epoch_samples = int(200 * 30) 20 | for dataset_name in dataset_name_list: 21 | print("Check %s" % dataset_name) 22 | npz_dir = os.path.abspath( 23 | os.path.join(DATASETS_PATH, dataset_name, "register_and_state") 24 | ) 25 | all_files = os.listdir(npz_dir) 26 | all_files.sort() 27 | all_files = [f for f in all_files if ".npz" in f] 28 | all_fname = np.array(all_files) 29 | all_files = [os.path.join(npz_dir, f) for f in all_files] 30 | all_files = np.array(all_files) 31 | 32 | all_std = [] 33 | all_n_pages = [] 34 | all_channels = [] 35 | for f in all_files: 36 | data_dict = np.load(f) 37 | tmp_signal = data_dict["signal"] 38 | tmp_std = tmp_signal.std() 39 | tmp_n_pages = tmp_signal.size / epoch_samples 40 | all_std.append(tmp_std) 41 | all_n_pages.append(tmp_n_pages) 42 | all_channels.append(data_dict["channel"]) 43 | print("Loaded %s" % f) 44 | all_std = np.array(all_std) 45 | all_n_pages = np.array(all_n_pages) 46 | 47 | print("\nReport:") 48 | print("Subjects %d" % len(all_std)) 49 | print( 50 | "STD - min %s, mean %s, max %s" 51 | % (all_std.min(), all_std.mean(), all_std.max()) 52 | ) 53 | print( 54 | "Pages - min %s, mean %s, max %s" 55 | % 
(all_n_pages.min(), all_n_pages.mean(), all_n_pages.max()) 56 | ) 57 | 58 | print("Channels found:") 59 | values, counts = np.unique(all_channels, return_counts=True) 60 | for v, c in zip(values, counts): 61 | print("%s: %d" % (v, c)) 62 | 63 | print("\nSmallest STD values: (Top %d)" % top_k) 64 | sorted_locs = np.argsort(all_std) 65 | for loc in sorted_locs[:top_k]: 66 | print( 67 | " File %s, STD %s, Pages %s" 68 | % (all_fname[loc], all_std[loc], all_n_pages[loc]) 69 | ) 70 | -------------------------------------------------------------------------------- /nsrr/check_raw_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | sys.path.append("..") 7 | 8 | from nsrr import nsrr_utils 9 | from nsrr.nsrr_utils import NSRR_DATA_PATHS, CHANNEL_PRIORITY_LABELS 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | dataset_name = "shhs1" 15 | verbose_missing_epoch = False 16 | reduced_number_of_subjects = None 17 | channel_pairs_list = None 18 | 19 | # ################################################################ 20 | 21 | print("Check %s" % dataset_name) 22 | edf_folder = NSRR_DATA_PATHS[dataset_name]["edf"] 23 | annot_folder = NSRR_DATA_PATHS[dataset_name]["annot"] 24 | print("Paths:") 25 | print(edf_folder) 26 | print(annot_folder) 27 | 28 | paths_dict = nsrr_utils.prepare_paths(edf_folder, annot_folder) 29 | subject_ids = list(paths_dict.keys()) 30 | 31 | if reduced_number_of_subjects is not None: 32 | # Reduced subset 33 | subject_ids = subject_ids[:reduced_number_of_subjects] 34 | 35 | print("Retrieved subjects: %d" % len(subject_ids)) 36 | 37 | epoch_length_list = [] 38 | first_label_start_list = [] 39 | channel_ids_list = [] 40 | for subject_id in subject_ids: 41 | edf_path = paths_dict[subject_id]["edf"] 42 | annot_path = paths_dict[subject_id]["annot"] 43 | 44 | # Hypnogram info 45 | stage_labels, stage_start_times, epoch_length = nsrr_utils.read_hypnogram( 46 | 
annot_path, verbose=verbose_missing_epoch 47 | ) 48 | epoch_length_list.append(epoch_length) 49 | first_label_start_list.append(stage_start_times[0]) 50 | 51 | total_pages = (stage_start_times[-1] + epoch_length) / epoch_length 52 | labeled_pages = len(stage_labels) 53 | if labeled_pages != total_pages: 54 | print( 55 | "Subject %s, hypno: %d labels, epochLength %s, first start %s, last start %s, required labels %s" 56 | % ( 57 | subject_id, 58 | len(stage_labels), 59 | epoch_length, 60 | stage_start_times[0], 61 | stage_start_times[-1], 62 | total_pages, 63 | ) 64 | ) 65 | 66 | # Signal info 67 | channel_names, fs_list = nsrr_utils.get_edf_info(edf_path) 68 | channel_found = None 69 | if channel_pairs_list is None: 70 | channel_pairs_list = CHANNEL_PRIORITY_LABELS 71 | for chn_pair in channel_pairs_list: 72 | if np.all([chn in channel_names for chn in chn_pair]): 73 | channel_found = chn_pair 74 | break 75 | if channel_found is None: 76 | print( 77 | "Subject %s without valid channels. 
Full list:" % subject_id, 78 | channel_names, 79 | ) 80 | else: 81 | channel_loc_1 = channel_names.index(channel_found[0]) 82 | channel_name_1 = channel_names[channel_loc_1] 83 | channel_fs_1 = fs_list[channel_loc_1] 84 | if len(channel_found) == 2: 85 | channel_loc_2 = channel_names.index(channel_found[1]) 86 | channel_name_2 = channel_names[channel_loc_2] 87 | channel_fs_2 = fs_list[channel_loc_2] 88 | else: 89 | channel_name_2 = "" 90 | channel_fs_2 = "" 91 | channel_str = "%s minus %s, fs %s minus %s" % ( 92 | channel_name_1, 93 | channel_name_2, 94 | channel_fs_1, 95 | channel_fs_2, 96 | ) 97 | channel_ids_list.append(channel_str) 98 | 99 | print("Epoch length:", np.unique(epoch_length_list)) 100 | print("First start:", np.unique(first_label_start_list)) 101 | print("Valid channels available:", np.unique(channel_ids_list)) 102 | -------------------------------------------------------------------------------- /nsrr/check_raw_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pprint import pprint 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | project_root = ".." 
9 | sys.path.append(project_root) 10 | 11 | DATASETS_PATH = os.path.join(project_root, "resources", "datasets", "nsrr") 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | dataset_name_list = [ 17 | "chat1", 18 | ] 19 | 20 | # Keys in dataset: 21 | # 'dataset' 22 | # 'subject_id' 23 | # 'channel' 24 | # 'signal' 25 | # 'sampling_rate' 26 | # 'hypnogram' 27 | # 'epoch_duration' 28 | # 'bandpass_filter' 29 | # 'resampling_function' 30 | # 'original_sampling_rate' 31 | 32 | for dataset_name in dataset_name_list: 33 | print("Check %s" % dataset_name) 34 | 35 | metadata_dir = os.path.abspath( 36 | os.path.join(DATASETS_PATH, dataset_name, "datasets") 37 | ) 38 | all_files = os.listdir(metadata_dir) 39 | all_files.sort() 40 | 41 | # variables 42 | var_file = [ 43 | f 44 | for f in all_files 45 | if ("variables" in f) and ("dictionary" in f) and (".csv" in f) 46 | ] 47 | pprint(var_file) 48 | 49 | var_file = os.path.join(metadata_dir, var_file[0]) 50 | 51 | var_df = pd.read_csv(var_file) 52 | var_df = var_df[[("Demographics" in s) for s in var_df["folder"]]] 53 | print(var_df[["folder", "id", "display_name", "type"]]) 54 | 55 | # Dataset file 56 | dataset_file = [f for f in all_files if ("dataset" in f) and (".csv" in f)] 57 | print("Datasets:") 58 | pprint(dataset_file) 59 | 60 | # id_col = var_df['id'] 61 | # useful_id = [s for s in id_col if ('age' in s) or ('sex' in s) or ('gender' in s)] 62 | # 63 | # print(useful_id) 64 | # 65 | # # load dataset 66 | # metadata_path = os.path.join(metadata_dir, 'shhs1-dataset-0.14.0.csv') 67 | # meta_df = pd.read_csv(metadata_path) 68 | # names = meta_df.columns 69 | # useful_names = [n for n in names if 'age' in n] 70 | # print(useful_names) 71 | -------------------------------------------------------------------------------- /nsrr/nsrr_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | import numpy as np 5 | import pyedflib 6 | 7 | 
from sleeprnn.data import utils 8 | 9 | # This path is expected to be the parent path of the paths listed in NSRR_DATA_PATHS 10 | NSRR_PATH = os.path.abspath("/home/ntapia/Projects/Sleep_Databases/NSRR_Databases") 11 | 12 | NSRR_DATA_PATHS = { 13 | "shhs1": { 14 | "edf": os.path.join(NSRR_PATH, "shhs/polysomnography/edfs/shhs1"), 15 | "annot": os.path.join( 16 | NSRR_PATH, "shhs/polysomnography/annotations-events-nsrr/shhs1" 17 | ), 18 | }, 19 | "mros1": { 20 | "edf": os.path.join(NSRR_PATH, "mros/polysomnography/edfs/visit1"), 21 | "annot": os.path.join( 22 | NSRR_PATH, "mros/polysomnography/annotations-events-nsrr/visit1" 23 | ), 24 | }, 25 | "chat1": { 26 | "edf": os.path.join(NSRR_PATH, "chat/polysomnography/edfs/visit1"), 27 | "annot": os.path.join( 28 | NSRR_PATH, "chat/polysomnography/annotations-events-nsrr/visit1" 29 | ), 30 | }, 31 | "ccshs": { 32 | "edf": os.path.join(NSRR_PATH, "ccshs/polysomnography/edfs"), 33 | "annot": os.path.join( 34 | NSRR_PATH, "ccshs/polysomnography/annotations-events-nsrr" 35 | ), 36 | }, 37 | "cfs": { 38 | "edf": os.path.join(NSRR_PATH, "cfs/polysomnography/edfs"), 39 | "annot": os.path.join(NSRR_PATH, "cfs/polysomnography/annotations-events-nsrr"), 40 | }, 41 | "sof": { 42 | "edf": os.path.join(NSRR_PATH, "sof/polysomnography/edfs"), 43 | "annot": os.path.join(NSRR_PATH, "sof/polysomnography/annotations-events-nsrr"), 44 | }, 45 | } 46 | 47 | # We need to extract a single EEG signal per subject. The following list of channel 48 | # pairs is used to select which signal to extract. It's mostly based on the rule of 49 | # prioritizing C3 derivations over C4 derivations, except for the SHHS dataset due to 50 | # noise in the C3 derivation. Tuples with two channel names are used to subtract the 51 | # second from the first. 
CHANNEL_PRIORITY_LABELS = [
    ("EEG",),  # C4-A1 in SHHS
    ("EEG(sec)",),  # C3-A2 in SHHS
    ("EEG2",),
    ("EEG 2",),
    ("EEG(SEC)",),
    ("EEG sec",),
    ("C3", "A2"),
    ("C3", "M2"),
    ("C3-A2",),
    ("C3-M2",),
    ("C4", "A1"),
    ("C4", "M1"),
    ("C4-A1",),
    ("C4-M1",),
]

# Factor that converts a signal expressed in the given EDF physical dimension
# into microvolts.
_UNITS_TO_MICROVOLT_FACTOR = {
    "V": 1e6,
    "mV": 1e3,
    "uV": 1.0,
}


def extract_id(fname, is_annotation):
    """Return the subject ID encoded in a file name.

    The extension (everything after the last ".") is dropped. For annotation
    files the last "-"-separated tag (e.g. "-nsrr") is dropped as well.

    Args:
        fname: File name, e.g. "subjectid.edf" or "subjectid-nsrr.xml".
        is_annotation: Whether the name follows the annotation nomenclature.

    Returns:
        The subject ID. It is the empty string when the expected separator
        is not present in the name.
    """
    # rpartition keeps what precedes the LAST separator ("" when it is absent)
    subject_id = fname.rpartition(".")[0]
    if is_annotation:
        subject_id = subject_id.rpartition("-")[0]
    return subject_id


def _index_files_by_id(folder, extension, is_annotation):
    """Map subject ID -> file name for the files in folder ending in extension."""
    files_by_id = {}
    # Sorted listing makes the choice deterministic if two files share an ID
    for fname in sorted(os.listdir(folder)):
        if fname.endswith(extension):
            files_by_id.setdefault(extract_id(fname, is_annotation), fname)
    return files_by_id


def prepare_paths(edf_folder, annot_folder):
    """Pair each EDF recording with its annotation file.

    Assuming annot_folder is the one in the NSRR format,
    nomenclature of files is
    edf: subjectid.edf
    xml: subjectid-nsrr.xml

    Only files whose names END with ".edf" / ".xml" are considered, so that
    leftovers such as "subject.edf.bak" are not mistaken for recordings.

    Args:
        edf_folder: Directory containing the EDF files.
        annot_folder: Directory containing the XML annotation files.

    Returns:
        A dict, ordered by subject ID, mapping each ID that has both files to
        {"edf": <path to the edf file>, "annot": <path to the xml file>}.
    """
    edf_by_id = _index_files_by_id(edf_folder, ".edf", is_annotation=False)
    annot_by_id = _index_files_by_id(annot_folder, ".xml", is_annotation=True)

    # Keep only IDs with both files
    common_ids = sorted(set(edf_by_id).intersection(annot_by_id))
    return {
        single_id: {
            "edf": os.path.join(edf_folder, edf_by_id[single_id]),
            "annot": os.path.join(annot_folder, annot_by_id[single_id]),
        }
        for single_id in common_ids
    }


def read_hypnogram(annot_path, verbose=False, assumed_epoch_length_if_missing=30):
    """Read the sleep stages of an XML annotation file as fixed-length epochs.

    Only the events whose "EventType" is "Stages|Stages" are used. Every such
    event is split into consecutive epochs of the file's epoch length.

    Args:
        annot_path: Path to the XML annotation file.
        verbose: If True, report when the epoch length is missing.
        assumed_epoch_length_if_missing: Epoch length, in seconds, used when
            the file has no "EpochLength" element or the element is empty.

    Returns:
        A tuple (stage_labels, stage_start_times, epoch_length):
        stage_labels is an array with one stage name per epoch,
        stage_start_times is a float32 array with the start time of each
        epoch in increasing order, and epoch_length is the epoch length.

    Raises:
        ValueError: If the file contains no sleep stage epochs.
    """
    root = ET.parse(annot_path).getroot()

    # The element itself may be absent, not only its text
    epoch_length_node = root.find("EpochLength")
    epoch_length_text = None if epoch_length_node is None else epoch_length_node.text
    if epoch_length_text is None or not epoch_length_text.strip():
        if verbose:
            print(
                "Missing epoch length, assuming %s [s]"
                % assumed_epoch_length_if_missing
            )
        epoch_length = assumed_epoch_length_if_missing
    else:
        epoch_length = float(epoch_length_text)

    scored_events = root.find("ScoredEvents")
    if scored_events is None:
        scored_events = []

    stage_labels = []
    stage_starts = []
    for event in scored_events:
        if event.findtext("EventType") != "Stages|Stages":
            continue
        stage_name = event.findtext("EventConcept")
        stage_start = float(event.findtext("Start"))
        stage_duration = float(event.findtext("Duration"))
        # Normalize variable-length epoch to a number of fixed-length epochs
        n_epochs = int(stage_duration / epoch_length)
        for i in range(n_epochs):
            stage_starts.append(stage_start + epoch_length * i)
            stage_labels.append(stage_name)

    if not stage_starts:
        raise ValueError("No sleep stage annotations found in %s" % annot_path)

    stage_labels = np.array(stage_labels)
    stage_starts = np.array(stage_starts, dtype=np.float64)
    # Stable sort keeps the file order among epochs with equal start time
    idx_sorted = np.argsort(stage_starts, kind="stable")
    stage_labels = stage_labels[idx_sorted]
    stage_start_times = stage_starts[idx_sorted].astype(np.float32)
    return stage_labels, stage_start_times, epoch_length


def get_edf_info(edf_path):
    """Return the channel labels of an EDF file and their sampling rates.

    Args:
        edf_path: Path to the EDF file.

    Returns:
        A tuple (channel_names, fs_list), where fs_list[i] is the sampling
        frequency of the i-th channel.
    """
    with pyedflib.EdfReader(edf_path) as file:
        channel_names = file.getSignalLabels()
        # Query by position: looking up by label would always hit the first
        # channel when a label is repeated in the file.
        fs_list = [file.samplefrequency(i) for i, _ in enumerate(channel_names)]
    return channel_names, fs_list


def read_signal_from_file(file, channel_name):
    """Read one channel of an open EDF reader, expressed in microvolts.

    Args:
        file: Open EDF reader exposing getSignalLabels, readSignal,
            getPhysicalDimension and samplefrequency.
        channel_name: Label of the channel to read.

    Returns:
        A tuple (signal, fs) with the signal in microvolts and its sampling
        frequency.

    Raises:
        ValueError: If channel_name is not a label of the file.
        KeyError: If the physical dimension of the channel is not supported.
    """
    channel_names = file.getSignalLabels()
    channel_to_extract = channel_names.index(channel_name)

    signal = file.readSignal(channel_to_extract)
    # Tolerate padding around the unit name
    units = file.getPhysicalDimension(channel_to_extract).strip()
    if units not in _UNITS_TO_MICROVOLT_FACTOR:
        raise KeyError(
            "Unsupported physical dimension %r for channel %r (expected one of %s)"
            % (units, channel_name, sorted(_UNITS_TO_MICROVOLT_FACTOR))
        )
    signal = signal * _UNITS_TO_MICROVOLT_FACTOR[units]
    fs = file.samplefrequency(channel_to_extract)
    return signal, fs


def read_edf_channel(edf_path, channel_priority_list):
    """Read the first available channel (or derivation) of a priority list.

    Args:
        edf_path: Path to the EDF file.
        channel_priority_list: List of tuples of channel labels, from most to
            least preferred. A tuple with two labels requests the first
            channel minus the second one.

    Returns:
        A tuple (signal, fs, channel_found), where channel_found is the tuple
        that was selected, or None if no tuple of the list is fully available
        in the file or if the two channels of the selected pair have
        different sampling frequencies. Callers must check for None before
        unpacking the result.
    """
    with pyedflib.EdfReader(edf_path) as file:
        available_names = set(file.getSignalLabels())
        channel_found = next(
            (
                chn_pair
                for chn_pair in channel_priority_list
                if all(chn in available_names for chn in chn_pair)
            ),
            None,
        )
        if channel_found is None:
            return None

        signal, fs = read_signal_from_file(file, channel_found[0])
        if len(channel_found) == 2:
            signal_2, fs_2 = read_signal_from_file(file, channel_found[1])
            if fs != fs_2:
                return None
            signal = signal - signal_2
    return signal, fs, channel_found


def short_signal_to_n2(signal, hypnogram, epoch_samples, n2_name):
    """
    Returns a cropped signal where only N2 stages are returned, ensuring one page of real signal
    at each border. This means that some non-N2 stages are kept, but they are a small portion.

    Args:
        signal: 1D signal aligned with the hypnogram.
        hypnogram: Array with one stage label per page.
        epoch_samples: Number of samples of a page.
        n2_name: Label that identifies the N2 stage in the hypnogram.

    Returns:
        A tuple (signal, hypnogram) restricted to the kept pages, with the
        signal flattened to 1D.
    """
    n2_pages = np.where(hypnogram == n2_name)[0]
    # Each N2 page plus its previous and next page, limited to existing pages
    valid_pages = np.concatenate([n2_pages - 1, n2_pages, n2_pages + 1])
    valid_pages = np.clip(valid_pages, a_min=0, a_max=(hypnogram.size - 1))
    valid_pages = np.unique(
        valid_pages
    )  # it is ensured to have context at each side of n2 pages

    # Now simplify
    hypnogram = hypnogram[valid_pages]

    # NOTE(review): utils.extract_pages is presumably returning one row per
    # requested page -- confirm against sleeprnn.data.utils.
    signal = utils.extract_pages(signal, valid_pages, epoch_samples)
    signal = signal.flatten()

    return signal, hypnogram


# Dump content that shared the replaced lines (start of the next file), kept verbatim:
# --------------------------------------------------------------------------------
# /nsrr/prepare_data.py:
# --------------------------------------------------------------------------------
# 1 | """Generates preprocessed subject data from NSRR datasets.
"""Generates preprocessed subject data from NSRR datasets.

It reads the raw files from each NSRR dataset, extracts the EEG signal and hypnogram,
resamples the signal to 200 Hz, applies a bandpass filter, and saves the preprocessed
data in a .npz file, one per subject.
"""

import os
import sys
import time

import numpy as np
from scipy.signal import correlate

project_root = ".."
sys.path.append(project_root)

from nsrr import nsrr_utils
from nsrr.nsrr_utils import NSRR_DATA_PATHS, CHANNEL_PRIORITY_LABELS
from sleeprnn.data import utils


DATASETS_PATH = os.path.join(project_root, "resources", "datasets", "nsrr")


def get_maximum_correlation_by_alignment(x, y):
    """Return the largest absolute normalized cross-correlation between x and y.

    Both signals are standardized (zero mean, unit standard deviation) and
    cross-correlated with mode "same"; the maximum absolute value over the
    tested alignments, divided by the signal size, is returned.

    Args:
        x: 1D array.
        y: 1D array with the same size as x.

    Returns:
        The maximum absolute correlation, or 2 if any of the signals is
        constant. This sentinel is larger than any value the normalized
        correlation can take, so a constant signal always compares as the
        most correlated one.

    Raises:
        ValueError: If the signals have different sizes.
    """
    if x.size != y.size:
        raise ValueError("Signals of different sizes")
    x_std = x.std()
    y_std = y.std()
    if x_std == 0 or y_std == 0:
        return 2
    x = (x - x.mean()) / x_std
    y = (y - y.mean()) / y_std
    possible_corrcoefs = correlate(x, y, mode="same") / x.size
    max_corrcoef = np.max(np.abs(possible_corrcoefs))
    return max_corrcoef


def _read_channel_or_fail(edf_path, channel_priority_list):
    """Read a channel with nsrr_utils.read_edf_channel, failing with a clear error.

    read_edf_channel returns None when no channel of the list is usable;
    unpacking that result directly would raise an uninformative TypeError.
    """
    result = nsrr_utils.read_edf_channel(edf_path, channel_priority_list)
    if result is None:
        raise ValueError(
            "None of the channels %s could be read from %s"
            % (channel_priority_list, edf_path)
        )
    return result


def _extract_pages_flat(signal, pages, page_size):
    """Return the requested pages of signal concatenated in a 1D array."""
    n_full_pages = signal.size // page_size
    return signal[: n_full_pages * page_size].reshape(-1, page_size)[pages].flatten()


def _select_shhs_channel(edf_path, stage_labels, stage_start_times, epoch_length, n2_id):
    """Select the SHHS EEG channel that is least correlated with the ECG.

    Both EEG channels are compared with the cardiac channel on the N2 pages.
    The secondary EEG is chosen when it is less correlated with the ECG than
    the primary one and its standard deviation is greater than 5; otherwise
    the primary EEG is chosen when its standard deviation is greater than 5.

    Returns:
        A tuple (signal, fs, channel_found) for the selected channel, with fs
        rounded to an integer.

    Raises:
        ValueError: If a required channel cannot be read, if there is no N2
            page inside the recording, or if both EEG channels have a
            standard deviation of 5 or less.
    """
    first_eeg_names = [("EEG",)]  # C4-A1 in SHHS
    second_eeg_names = [
        ("EEG(sec)",),  # C3-A2 in SHHS
        ("EEG2",),
        ("EEG 2",),
        ("EEG(SEC)",),
        ("EEG sec",),
    ]
    cardiac_names = [("ECG",)]
    print("ECG correlation computation")
    signal_a, fs_a, channel_found_a = _read_channel_or_fail(edf_path, first_eeg_names)
    signal_b, fs_b, channel_found_b = _read_channel_or_fail(edf_path, second_eeg_names)
    signal_cardiac, fs_cardiac, _ = _read_channel_or_fail(edf_path, cardiac_names)
    fs_cardiac = int(np.round(fs_cardiac))
    fs_a = int(np.round(fs_a))
    fs_b = int(np.round(fs_b))
    if fs_cardiac != fs_a:
        print(
            "Resampling cardiac signal from %s Hz to %s Hz"
            % (fs_cardiac, fs_a)
        )
        signal_cardiac = utils.resample_signal(
            signal_cardiac, fs_old=fs_cardiac, fs_new=fs_a
        )
    # The comparison below pages every signal with the sampling rate of the
    # first EEG, so a copy of the second EEG is aligned to that rate if needed.
    # The returned signal is never this copy.
    compared_b = signal_b
    if fs_b != fs_a:
        print(
            "Resampling second EEG from %s Hz to %s Hz for the comparison"
            % (fs_b, fs_a)
        )
        compared_b = utils.resample_signal(signal_b, fs_old=fs_b, fs_new=fs_a)

    # generate short signals (N2 only)
    tmp_epoch_samples = int(epoch_length * fs_a)
    valid_starts = stage_start_times[stage_labels == n2_id]
    valid_pages = (valid_starts / epoch_length).astype(np.int32)
    # Only pages fully contained in the three signals can be compared
    n_full_pages = (
        min(signal_a.size, compared_b.size, signal_cardiac.size) // tmp_epoch_samples
    )
    valid_pages = valid_pages[valid_pages < n_full_pages]
    if valid_pages.size == 0:
        raise ValueError("No N2 pages inside the recording %s" % edf_path)
    tmp_signal_a = _extract_pages_flat(signal_a, valid_pages, tmp_epoch_samples)
    tmp_signal_b = _extract_pages_flat(compared_b, valid_pages, tmp_epoch_samples)
    tmp_signal_cardiac = _extract_pages_flat(
        signal_cardiac, valid_pages, tmp_epoch_samples
    )

    # measure correlation
    corr_a = get_maximum_correlation_by_alignment(tmp_signal_a, tmp_signal_cardiac)
    corr_b = get_maximum_correlation_by_alignment(tmp_signal_b, tmp_signal_cardiac)
    print("Correlations -- EEG: %1.4f -- EEG(sec): %1.4f" % (corr_a, corr_b))
    std_a = tmp_signal_a.std()
    std_b = tmp_signal_b.std()
    if (corr_b < corr_a) and (std_b > 5):
        signal, fs, channel_found = signal_b, fs_b, channel_found_b
    elif std_a > 5:
        signal, fs, channel_found = signal_a, fs_a, channel_found_a
    else:
        raise ValueError("Both std less than 5")
    print("%s selected." % channel_found[0])
    return signal, fs, channel_found


if __name__ == "__main__":

    keep_only_n2 = True  # If True, only N2 epochs are kept. Allows saving resources.

    # Set which dataset to process by this script:
    dataset_name_list = [
        "shhs1",
    ]

    # This flag is for debugging purposes. If None, all subjects are processed.
    # If an integer is given, only that number of subjects is processed.
    reduced_number_of_subjects = None

    # ##################
    # AUXILIARY VARIABLES

    unknown_stage_label = "?"
    n2_id = "Stage 2 sleep|2"
    target_fs = 200  # Hz

    # ##################
    # READ

    for dataset_name in dataset_name_list:

        save_dir = os.path.abspath(
            os.path.join(DATASETS_PATH, dataset_name, "register_and_state")
        )
        os.makedirs(save_dir, exist_ok=True)

        print("\nReading %s" % dataset_name)
        edf_folder = NSRR_DATA_PATHS[dataset_name]["edf"]
        annot_folder = NSRR_DATA_PATHS[dataset_name]["annot"]
        print("From paths:")
        print(edf_folder)
        print(annot_folder)
        paths_dict = nsrr_utils.prepare_paths(edf_folder, annot_folder)
        subject_ids = list(paths_dict.keys())

        if reduced_number_of_subjects is not None:
            # Reduced subset
            subject_ids = subject_ids[:reduced_number_of_subjects]

        n_subjects = len(subject_ids)
        print("Retrieved subjects: %d" % n_subjects)
        print("Preprocessed files will be saved at %s" % save_dir)

        start_time = time.time()
        for i_sub, subject_id in enumerate(subject_ids):
            print(
                "\nProcessing subject %s (%04d/%d)"
                % (subject_id, i_sub + 1, n_subjects)
            )

            # Read data
            stage_labels, stage_start_times, epoch_length = nsrr_utils.read_hypnogram(
                paths_dict[subject_id]["annot"]
            )

            if "shhs" in dataset_name:
                # SHHS specific processing
                signal, fs, channel_found = _select_shhs_channel(
                    paths_dict[subject_id]["edf"],
                    stage_labels,
                    stage_start_times,
                    epoch_length,
                    n2_id,
                )
            else:
                signal, fs, channel_found = _read_channel_or_fail(
                    paths_dict[subject_id]["edf"], CHANNEL_PRIORITY_LABELS
                )

            # Channel id
            channel_id = " minus ".join(channel_found)

            # Filter and resample
            original_sampling_rate = fs

            # Transform the original fs frequency with decimals to rounded version if necessary
            fs_round = int(np.round(fs))
            if np.abs(fs_round - fs) > 1e-8:
                print("Linear interpolation from %s Hz to %s Hz" % (fs, fs_round))
                signal = utils.resample_signal_linear(
                    signal, fs_old=fs, fs_new=fs_round
                )
            fs = fs_round

            # Broad bandpass filter to signal
            signal = utils.broad_filter(signal, fs, lowcut=0.1, highcut=35)
            # Now resample to the required frequency
            if fs != target_fs:
                print(
                    "Resampling channel %s from %s Hz to required %s Hz"
                    % (channel_id, fs, target_fs)
                )
                signal = utils.resample_signal(signal, fs_old=fs, fs_new=target_fs)
                resample_method = "scipy.signal.resample_poly"
            else:
                print(
                    "Signal channel %s already at required %s Hz"
                    % (channel_id, target_fs)
                )
                resample_method = "none"
            fs = target_fs

            # Ensure first label starts at t = 0
            valid_start_sample = int(stage_start_times[0] * fs)
            signal = signal[valid_start_sample:]
            stage_start_times = stage_start_times - stage_start_times[0]

            # Fill hypnogram if necessary
            hypno_total_pages = int(
                (stage_start_times[-1] + epoch_length) / epoch_length
            )
            hypnogram = np.array(
                [unknown_stage_label] * hypno_total_pages, dtype=stage_labels.dtype
            )
            labeled_locs = (stage_start_times / epoch_length).astype(np.int32)
            hypnogram[labeled_locs] = stage_labels

            # Ensure that both the signal and the hypnogram end at the same time
            epoch_samples = int(epoch_length * fs)
            signal_total_full_pages = int(signal.size // epoch_samples)
            valid_total_pages = min(hypno_total_pages, signal_total_full_pages)
            valid_total_samples = int(valid_total_pages * epoch_samples)
            hypnogram = hypnogram[:valid_total_pages]
            signal = signal[:valid_total_samples]

            if keep_only_n2:
                signal, hypnogram = nsrr_utils.short_signal_to_n2(
                    signal, hypnogram, epoch_samples, n2_id
                )

            # Save subject data
            subject_data_dict = {
                "dataset": dataset_name,
                "subject_id": subject_id,
                "channel": channel_id,
                "signal": signal.astype(np.float32),
                "sampling_rate": fs,
                "hypnogram": hypnogram,
                "epoch_duration": epoch_length,
                "bandpass_filter": "scipy.signal.butter, 0.1-35Hz, order 3",
                "resampling_function": resample_method,
                "original_sampling_rate": original_sampling_rate,
            }
            fpath = os.path.join(save_dir, "%s.npz" % subject_id)
            np.savez(fpath, **subject_data_dict)

            elapsed_time = time.time() - start_time
            print("E.T. %1.4f [s]" % elapsed_time)


# Dump content that shared the replaced lines (start of the next file), kept verbatim:
# --------------------------------------------------------------------------------
# /nsrr/prepare_metadata.py:
# --------------------------------------------------------------------------------
# 1 | """Generates preprocessed metadata (age and sex) from NSRR datasets.
# 2 |
# 3 | It creates a csv file with columns subject_id, age and sex, one row per subject,
# 4 | one file per dataset. This file allows analyzing the distribution of per-subject
# 5 | sleep spindle features by sex and age.
# 6 | """
# 7 |
# 8 | import os
# 9 | import sys
# 10 | from pprint import pprint
# 11 |
# 12 | import numpy as np
# 13 | import pandas as pd
# 14 |
# 15 | project_root = ".."
16 | sys.path.append(project_root) 17 | 18 | DATASETS_PATH = os.path.join(project_root, "resources", "datasets", "nsrr") 19 | 20 | 21 | def prepare_suffix(subject_id, dataset_name): 22 | if dataset_name == "mros1": 23 | return subject_id.lower() 24 | elif dataset_name == "sof": 25 | return "%05d" % subject_id 26 | else: 27 | return str(subject_id) 28 | 29 | 30 | if __name__ == "__main__": 31 | 32 | configs = [ 33 | dict( 34 | dataset_name="shhs1", 35 | filename="shhs1-dataset-0.14.0.csv", 36 | subject_id="nsrrid", 37 | age="age_s1", 38 | sex="gender", 39 | sex_map={1: "m", 2: "f"}, 40 | prefix="shhs1-", 41 | ), 42 | dict( 43 | dataset_name="mros1", 44 | filename="mros-visit1-dataset-0.5.0.csv", 45 | subject_id="nsrrid", 46 | age="vsage1", 47 | sex="gender", 48 | sex_map={2: "m"}, 49 | prefix="mros-visit1-", 50 | ), 51 | dict( 52 | dataset_name="chat1", 53 | filename="chat-baseline-dataset-0.11.0.csv", 54 | subject_id="nsrrid", 55 | age="ageyear_at_meas", 56 | sex="chi2", 57 | sex_map={1: "m", 2: "f"}, 58 | prefix="chat-baseline-", 59 | ), 60 | dict( 61 | dataset_name="chat1", 62 | filename="chat-nonrandomized-dataset-0.11.0.csv", 63 | subject_id="nsrrid", 64 | age="age_nr", 65 | sex="ref9", 66 | sex_map={1: "m", 2: "f"}, 67 | prefix="chat-baseline-nonrandomized-", 68 | ), 69 | dict( 70 | dataset_name="sof", 71 | filename="sof-visit-8-dataset-0.6.0.csv", 72 | subject_id="sofid", 73 | age="V8AGE", 74 | sex="gender", 75 | sex_map={1: "f"}, 76 | prefix="sof-visit-8-", 77 | ), 78 | dict( 79 | dataset_name="cfs", 80 | filename="cfs-visit5-dataset-0.5.0.csv", 81 | subject_id="nsrrid", 82 | age="age", 83 | sex="SEX", 84 | sex_map={0: "f", 1: "m"}, 85 | prefix="cfs-visit5-", 86 | ), 87 | dict( 88 | dataset_name="ccshs", 89 | filename="ccshs-trec-dataset-0.6.0.csv", 90 | subject_id="nsrrid", 91 | age="age", 92 | sex="male", 93 | sex_map={0: "f", 1: "m"}, 94 | prefix="ccshs-trec-", 95 | ), 96 | ] 97 | 98 | for config in configs: 99 | print("\nProcessing %s" % 
config["prefix"]) 100 | metadata_file = os.path.join( 101 | DATASETS_PATH, config["dataset_name"], "datasets", config["filename"] 102 | ) 103 | meta_df = pd.read_csv(metadata_file, low_memory=False) 104 | # subject id 105 | subject_ids = meta_df[config["subject_id"]].values 106 | subject_ids = [ 107 | prepare_suffix(sub_id, config["dataset_name"]) for sub_id in subject_ids 108 | ] 109 | subject_ids = np.array( 110 | ["%s%s" % (config["prefix"], sub_id) for sub_id in subject_ids], 111 | dtype=" valid_range[1] or value < valid_range[0]: 13 | msg = "Expected range %s for %s, but %s was provided." % ( 14 | valid_range, 15 | name, 16 | value, 17 | ) 18 | raise ValueError(msg) 19 | 20 | 21 | def check_valid_value(value, name, valid_list): 22 | """Raises a ValueError exception if value not in valid_list""" 23 | if value not in valid_list: 24 | msg = "Expected %s for %s, but %s was provided." % (valid_list, name, value) 25 | raise ValueError(msg) 26 | 27 | 28 | def check_directory(path_dir): 29 | """Raises FileNotFoundError exception if directory doesn't exists""" 30 | if not os.path.isdir(path_dir): 31 | raise FileNotFoundError("Directory not found: %s" % path_dir) 32 | 33 | 34 | def ensure_directory(path_dir): 35 | """If directory doesn't exists, is created.""" 36 | os.makedirs(path_dir, exist_ok=True) 37 | -------------------------------------------------------------------------------- /sleeprnn/common/constants.py: -------------------------------------------------------------------------------- 1 | """constants.py: Module that stores several useful constants for the project.""" 2 | 3 | # Dataset name 4 | MASS_SS_NAME = "mass_ss" 5 | MASS_KC_NAME = "mass_kc" 6 | INTA_SS_NAME = "inta_ss" 7 | MODA_SS_NAME = "moda_ss" 8 | CAP_SS_NAME = "cap_ss" 9 | PINK_NAME = "pink_nn" 10 | NSRR_SS_NAME = "nsrr_ss" 11 | 12 | # Database split 13 | TRAIN_SUBSET = "train" 14 | VAL_SUBSET = "val" 15 | TEST_SUBSET = "test" 16 | ALL_TRAIN_SUBSET = "all_train" 17 | 18 | # Task mode 19 | 
N2_RECORD = "n2" 20 | WN_RECORD = "wn" 21 | 22 | # Event names 23 | SPINDLE = "spindle" 24 | KCOMPLEX = "kcomplex" 25 | 26 | # Metric keys 27 | AF1 = "af1" 28 | F1_SCORE = "f1_score" 29 | PRECISION = "precision" 30 | RECALL = "recall" 31 | TP = "tp" 32 | FP = "fp" 33 | FN = "fn" 34 | MEAN_ALL_IOU = "mean_all_iou" 35 | MEAN_NONZERO_IOU = "mean_nonzero_iou" 36 | MACRO_AVERAGE = "macro_average" 37 | MICRO_AVERAGE = "micro_average" 38 | 39 | # Baselines data keys 40 | F1_VS_IOU = "f1_vs_iou" 41 | RECALL_VS_IOU = "recall_vs_iou" 42 | PRECISION_VS_IOU = "precision_vs_iou" 43 | IOU_HIST_BINS = "iou_hist_bins" 44 | IOU_CURVE_AXIS = "iou_curve_axis" 45 | IOU_HIST_VALUES = "iou_hist_values" 46 | MEAN_IOU = "mean_iou" 47 | MEAN_AF1 = "mean_af1" 48 | IQR_LOW_IOU = "iqr_low_iou" 49 | IQR_HIGH_IOU = "iqr_high_iou" 50 | 51 | # Type of padding 52 | PAD_SAME = "same" 53 | PAD_VALID = "valid" 54 | 55 | # Type of batch normalization 56 | BN = "bn" 57 | BN_RENORM = "bn_renorm" 58 | 59 | # Type of pooling 60 | MAXPOOL = "maxpool" 61 | AVGPOOL = "avgpool" 62 | 63 | # Alternative to pooling after convolution 64 | STRIDEDCONV = "stridedconv" 65 | 66 | # Type of dropout 67 | REGULAR_DROP = "regular_dropout" 68 | SEQUENCE_DROP = "sequence_dropout" 69 | 70 | # Number of directions for recurrent layers 71 | UNIDIRECTIONAL = "unidirectional" 72 | BIDIRECTIONAL = "bidirectional" 73 | 74 | # Type of class weights 75 | BALANCED = "balanced" 76 | BALANCED_DROP = "balanced_drop" 77 | BALANCED_DROP_V2 = "balanced_drop_v2" 78 | 79 | # Types of losses 80 | CROSS_ENTROPY_LOSS = "cross_entropy_loss" 81 | DICE_LOSS = "dice_loss" 82 | FOCAL_LOSS = "focal_loss" 83 | WORST_MINING_LOSS = "worst_mining_loss" 84 | WORST_MINING_V2_LOSS = "worst_mining_v2_loss" 85 | CROSS_ENTROPY_NEG_ENTROPY_LOSS = "cross_entropy_neg_entropy_loss" 86 | CROSS_ENTROPY_SMOOTHING_LOSS = "cross_entropy_smoothing_loss" 87 | CROSS_ENTROPY_HARD_CLIP_LOSS = "cross_entropy_hard_clip_loss" 88 | CROSS_ENTROPY_SMOOTHING_CLIP_LOSS = 
"cross_entropy_smoothing_clip_loss" 89 | MOD_FOCAL_LOSS = "mod_focal_loss" 90 | CROSS_ENTROPY_BORDERS_LOSS = "cross_entropy_borders_loss" 91 | CROSS_ENTROPY_BORDERS_IND_LOSS = "cross_entropy_borders_ind_loss" 92 | WEIGHTED_CROSS_ENTROPY_LOSS = "weighted_cross_entropy_loss" 93 | WEIGHTED_CROSS_ENTROPY_LOSS_HARD = "weighted_cross_entropy_loss_hard" 94 | WEIGHTED_CROSS_ENTROPY_LOSS_SOFT = "weighted_cross_entropy_loss_soft" 95 | WEIGHTED_CROSS_ENTROPY_LOSS_V2 = "weighted_cross_entropy_loss_v2" 96 | WEIGHTED_CROSS_ENTROPY_LOSS_V3 = "weighted_cross_entropy_loss_v3" 97 | WEIGHTED_CROSS_ENTROPY_LOSS_V4 = "weighted_cross_entropy_loss_v4" 98 | HINGE_LOSS = "hinge_loss" 99 | WEIGHTED_CROSS_ENTROPY_LOSS_V5 = "weighted_cross_entropy_loss_v5" 100 | CROSS_ENTROPY_LOSS_WITH_LOGITS_REG = "cross_entropy_loss_with_logits_reg" 101 | CROSS_ENTROPY_LOSS_WITH_SELF_SUPERVISION = "cross_entropy_loss_with_self_supervision" 102 | MASKED_SOFT_FOCAL_LOSS = "masked_soft_focal_loss" 103 | 104 | # Types of logits reg 105 | LOGITS_REG_NORM = "logits_reg_norm" 106 | LOGITS_REG_NORM_SQRT = "logits_reg_norm_sqrt" 107 | LOGITS_REG_ATTRACTOR = "logits_reg_attractor" 108 | LOGITS_REG_ATTRACTOR_SQRT = "logits_reg_attractor_sqrt" 109 | 110 | # Mix weight strategies 111 | MIX_WEIGHTS_SUM = "mix_weights_sum" 112 | MIX_WEIGHTS_PRODUCT = "mix_weights_product" 113 | MIX_WEIGHTS_MAX = "mix_weights_max" 114 | 115 | # Types of optimizer 116 | ADAM_OPTIMIZER = "adam_optimizer" 117 | ADAM_W_OPTIMIZER = "adam_w_optimizer" 118 | SGD_OPTIMIZER = "sgd_optimizer" 119 | RMSPROP_OPTIMIZER = "rmsprop_optimizer" 120 | 121 | # Training params 122 | LOSS_CRITERION = "loss_criterion" 123 | METRIC_CRITERION = "metric_criterion" 124 | 125 | # Normalization computation mode 126 | NORM_IQR = "norm_iqr" 127 | NORM_STD = "norm_std" 128 | NORM_GLOBAL = "norm_global" 129 | 130 | # Type of masking for wave augmentation 131 | MASK_KEEP_EVENTS = "mask_keep_events" 132 | MASK_KEEP_BACKGROUND = "mask_keep_background" 133 | MASK_NONE = 
"mask_none" 134 | 135 | # Type of dispersion mode for A7 features 136 | DISPERSION_STD = "std" 137 | DISPERSION_MADE = "made" 138 | DISPERSION_STD_ROBUST = "std_robust" 139 | 140 | # Colors 141 | RED = "red" 142 | BLUE = "blue" 143 | GREEN = "green" 144 | GREY = "grey" 145 | DARK = "dark" 146 | CYAN = "cyan" 147 | PURPLE = "purple" 148 | 149 | # Model versions 150 | DUMMY = "dummy" 151 | DEBUG = "debug" 152 | V1 = "v1" 153 | V4 = "v4" 154 | V5 = "v5" 155 | V6 = "v6" 156 | V7 = "v7" 157 | V8 = "v8" 158 | V9 = "v9" 159 | V7lite = "v7lite" 160 | V7litebig = "v7litebig" 161 | V10 = "v10" 162 | V11 = "v11" # Time-domain 163 | V12 = "v12" 164 | V13 = "v13" 165 | V14 = "v14" 166 | V15 = "v15" # Mixed 167 | V16 = "v16" # Mixed 168 | V17 = "v17" 169 | V18 = "v18" # Mixed 170 | V19 = "v19" 171 | V20_INDEP = "v20_indep" # time 172 | V20_CONCAT = "v20_concat" # time 173 | V21 = "v21" # Mixed 174 | V22 = "v22" # CWT indep 175 | V23 = "v23" # Time-domain with LSTM instead of FC 176 | V24 = "v24" # Time-domain with feed-forward UpConv output 177 | V25 = "v25" # Time-domain unet 178 | V11_SKIP = "v11_skip" 179 | V19_SKIP = "v19_skip" 180 | V19_SKIP2 = "v19_skip2" 181 | V19_SKIP3 = "v19_skip3" 182 | V26 = "v26" # Experimental skip (based on v19) 183 | V27 = "v27" # Experimental skip (based on v11) 184 | V28 = "v28" # Experimental skip (based on v11) 185 | V29 = "v29" # Experimental skip (based on v11) 186 | V30 = "v30" # Experimental skip (based on v11) 187 | V115 = "v115" # v11 with kernel 5 188 | V195 = "v195" # v19 with kernel 5 189 | V11G = "v11g" # v11 with GRU instead of LSTM 190 | V19G = "v19g" # V19 with GRU instead of LSTM 191 | V31 = "v31" # v19 with independent branches for each band and 2 convs 192 | V32 = "v32" # v19 with independent branches for each band and 3 convs 193 | V19P = "v19p" # v19 with conv1x1 projection before lstm 194 | V33 = "v33" # V19 with independent LSTM's in first layer 195 | V34 = "v34" # V19 with 1D convolutions (frequencies as channels) 196 | 
ATT01 = "att01" # 1d attention basic 197 | ATT02 = "att02" # 1d attention with lstm 198 | ATT03 = "att03" # 1d attention with lstm (better) 199 | ATT04 = "att04" # 1d attention with 2 lstm 200 | ATT04C = "att04c" # 1d attention with 2 lstm and concat PE. 201 | V35 = "v35" # RED-Time+CWT (at last FC) 202 | V11_ABLATION = "v11_ablation" # BN and Dropout ablation 203 | V11_ABLATION_SCALED = "v11_ablation_scaled" # BN and Dropout ablation with scaled input 204 | V11_D6K5 = "v11_d6k5" 205 | V11_D8K3 = "v11_d8k3" 206 | V11_D8K5 = "v11_d8k5" 207 | V11_OUTRES = "v11_outres" 208 | V11_OUTPLUS = "v11_outplus" 209 | V11_SHIELD = "v11_shield" 210 | V11_LITE = "v11_lite" 211 | V11_NORM = "v11_norm" 212 | V11_PR_1 = "v11_pr_1" 213 | V11_PR_2P = "v11_pr_2p" 214 | V11_PR_2C = "v11_pr_2c" 215 | V11_PR_3P = "v11_pr_3p" 216 | V11_PR_3C = "v11_pr_3c" 217 | V11_LLC_STFT = "v11_llc_stft" # All conv blocks 218 | V11_LLC_STFT_1 = "v11_llc_stft_1" # 3rd conv block 219 | V11_LLC_STFT_2 = "v11_llc_stft_2" # 1st fc 220 | V11_LLC_STFT_3 = "v11_llc_stft_3" # logits 221 | V19_LLC_STFT_2 = "v19_llc_stft_2" # 1st fc 222 | V19_LLC_STFT_3 = "v19_llc_stft_3" # logits 223 | TCN01 = "tcn01" # TCN time and generic block design. 224 | TCN02 = "tcn02" # TCN time and dilations go up and then down 225 | TCN03 = "tcn03" # TCN time without residual blocks. 226 | TCN04 = "tcn04" # TCN time without residual blocks, with dilations up-down. 227 | V19_FROZEN = "v19_frozen" # BN at CWT is replaced by fixed normalization. 
228 | ATT05 = "att05" # 2d attention with 2 lstm 229 | V19_VAR = "v19_var" # v19 with pooled scales at cwt 230 | V19_NOISY = "v19_noisy" # v19 but with uniform noise at scales 231 | A7_V1 = "a7_v1" # cnn DeepA7 model 232 | A7_V2 = "a7_v2" # lstm DeepA7 model 233 | A7_V3 = "a7_v3" # RED-A7 model 234 | V11_BP = "v11_bp" # band-passed input 235 | V19_BP = "v19_bp" # band-passed input 236 | V11_LN = "v11_ln" # zscore at last conv 237 | V11_LN2 = "v11_ln2" # zscore at logits 238 | V19_LN2 = "v19_ln2" # zscore at logits 239 | V11_LN3 = "v11_ln3" # zscore at both last conv and logits 240 | V11_MK = "v11_mk" # V11 but with multiple kernel sizes at each conv 241 | V11_MKD = "v11_mkd" # V11 but with multiple dilations at each conv 242 | V11_MKD2 = ( 243 | "v11_mkd2" # V11 but with multiple dilations at each conv, border crop after lstm 244 | ) 245 | V11_MKD2_STATMOD = "v11_mkd2_statmod" # V11_MKD2 but with stat net modulation 246 | V11_MKD2_STATDOT = ( 247 | "v11_mkd2_statdot" # V11_MKD2 but with stat net dot product for class scores 248 | ) 249 | V36 = ( 250 | "v36" # Based on RED-Time but with bandpass filtering and independent conv branches 251 | ) 252 | V11_ATT = "v11_att" # RED-Time with attention at output layer 253 | V11_MKD2_EXPERTMOD = "v11_mkd2_expertmod" # V11_MKD2 but with expert modulation 254 | V11_MKD2_EXPERTREG = "v11_mkd2_expertreg" # V11_MKD2 but with expert regression 255 | V11_MKD2_SWISH = "v11_mkd2_swish" # v11-mkd2 but with swish activation instead of relu 256 | V41 = "v41" # BIGGER NET (1D residual stages, possibly with dilations at last stage) 257 | V42 = "v42" # v41 but with self att instead of lstm 258 | V43 = "v43" # a lego extension of v41/v42 to support additional options for BigNet 259 | V2_TIME = "v2_time" # REDv2-Time 260 | V2_CWT1D = "v2_cwt1d" # REDv2-CWT1D 261 | V2_CWT2D = "v2_cwt2d" # REDv2-CWT2D 262 | -------------------------------------------------------------------------------- /sleeprnn/common/viz.py: 
-------------------------------------------------------------------------------- 1 | from IPython.core.display import display, HTML 2 | 3 | from sleeprnn.common import constants 4 | 5 | PALETTE = { 6 | constants.RED: "#c62828", 7 | constants.GREY: "#455a64", 8 | constants.BLUE: "#0277bd", 9 | constants.GREEN: "#43a047", 10 | constants.DARK: "#1b2631", 11 | constants.CYAN: "#00838F", 12 | constants.PURPLE: "#8E24AA", 13 | } 14 | 15 | GREY_COLORS = { 16 | 0: "#fafafa", 17 | 1: "#f5f5f5", 18 | 2: "#eeeeee", 19 | 3: "#e0e0e0", 20 | 4: "#bdbdbd", 21 | 5: "#9e9e9e", 22 | 6: "#757575", 23 | 7: "#616161", 24 | 8: "#424242", 25 | 9: "#212121", 26 | } 27 | 28 | COMPARISON_COLORS = { 29 | "model": PALETTE["cyan"], 30 | "expert": PALETTE["blue"], 31 | "baseline": GREY_COLORS[7], 32 | } 33 | 34 | BASELINES_LABEL_MARKER = { 35 | "2019_chambon": ("DOSED", "s"), 36 | "2019_lacourse": ("A7", "d"), 37 | "2017_lajnef": ("Spinky", "p"), 38 | } 39 | 40 | DPI = 200 41 | FONTSIZE_TITLE = 9 42 | FONTSIZE_GENERAL = 8 43 | AXIS_COLOR = GREY_COLORS[8] 44 | LINEWIDTH = 1.1 45 | MARKERSIZE = 5 46 | LEGEND_LABEL_SPACING = 1.1 47 | 48 | 49 | def notebook_full_width(): 50 | display(HTML("")) 51 | -------------------------------------------------------------------------------- /sleeprnn/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolasigor/Sleep-EEG-Event-Detector/24322278e3f3ef7535413a65a61fffc3ce4f4e01/sleeprnn/data/__init__.py -------------------------------------------------------------------------------- /sleeprnn/data/mass_raw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pyedflib 3 | import numpy as np 4 | 5 | from sleeprnn.data import utils 6 | 7 | PATH_REC = "register" 8 | PATH_MARKS = os.path.join("label", "spindle") 9 | PATH_STATES = os.path.join("label", "state") 10 | KEY_FILE_EEG = "file_eeg" 11 | KEY_FILE_STATES = "file_states" 12 
class MassRaw(object):
    """Reader of raw MASS recordings: one EEG channel plus its hypnogram.

    Paths are built from the module-level PATH_* constants under
    ``utils.PATH_DATA/mass``. Subjects listed in IDS_INVALID are excluded and
    the ones in IDS_TEST form the test split.
    """

    def __init__(self):
        self.fs = 256
        self.page_duration = 20
        self.page_size = int(self.page_duration * self.fs)
        self.channel = "EEG C3-CLE"
        self.state_ids = np.array(["1", "2", "3", "4", "R", "W", "?"])
        self.unknown_id = "?"  # Character for unknown state in hypnogram
        self.n2_id = "2"  # Character for N2 identification in hypnogram
        valid_ids = [i for i in range(1, 20) if i not in IDS_INVALID]
        self.test_ids = IDS_TEST
        self.train_ids = [i for i in valid_ids if i not in self.test_ids]
        self.dataset_dir = os.path.abspath(os.path.join(utils.PATH_DATA, "mass"))
        self.all_ids = self.train_ids + self.test_ids
        self.dataset_name = "mass_raw"

    def _get_subject_paths(self, subject_id):
        """Returns the dict of file paths of a single subject."""
        file_prefix = "01-02-%04d" % subject_id
        return {
            KEY_FILE_EEG: os.path.join(
                self.dataset_dir, PATH_REC, "%s PSG.edf" % file_prefix
            ),
            KEY_FILE_STATES: os.path.join(
                self.dataset_dir, PATH_STATES, "%s Base.edf" % file_prefix
            ),
            "%s_1" % KEY_FILE_MARKS: os.path.join(
                self.dataset_dir, PATH_MARKS, "%s SpindleE1.edf" % file_prefix
            ),
            "%s_2" % KEY_FILE_MARKS: os.path.join(
                self.dataset_dir, PATH_MARKS, "%s SpindleE2.edf" % file_prefix
            ),
        }

    @staticmethod
    def _report_missing_files(path_dict):
        """Prints a warning for every path of 'path_dict' that is not a file."""
        for path in path_dict.values():
            if not os.path.isfile(path):
                print("File not found: %s" % path)

    def _get_file_paths(self):
        """Returns a dict (subject id -> dict of paths) to load the database."""
        data_paths = {}
        for subject_id in self.all_ids:
            ind_dict = self._get_subject_paths(subject_id)
            self._report_missing_files(ind_dict)
            data_paths[subject_id] = ind_dict
        return data_paths

    def get_subject_data(self, subject_id):
        """Returns the cropped signal and the hypnogram of a single subject.

        Raises:
            KeyError: if 'subject_id' is not one of self.all_ids.
        """
        if subject_id not in self.all_ids:
            raise KeyError(subject_id)
        # Only the requested subject is resolved and checked, so warnings about
        # files of unrelated subjects are not printed on every call.
        path_dict = self._get_subject_paths(subject_id)
        self._report_missing_files(path_dict)
        signal = self._read_eeg(path_dict[KEY_FILE_EEG])
        hypnogram, start_sample = self._read_states_raw(path_dict[KEY_FILE_STATES])
        signal, hypnogram, _ = self._fix_signal_and_states(
            signal, hypnogram, start_sample
        )
        return signal, hypnogram

    def _read_eeg(self, path_eeg_file):
        """Loads channel self.channel from 'path_eeg_file' as float32.

        The signal is passed through utils.resample_signal_linear to go from
        the (possibly fractional) sampling frequency stored in the file to its
        rounded integer value.
        """
        with pyedflib.EdfReader(path_eeg_file) as file:
            channel_names = file.getSignalLabels()
            channel_to_extract = channel_names.index(self.channel)
            signal = file.readSignal(channel_to_extract)
            fs_old = file.samplefrequency(channel_to_extract)
            # Particular fix for mass dataset:
            fs_old_round = int(np.round(fs_old))
        # Transform the original fs frequency with decimals to rounded version
        signal = utils.resample_signal_linear(
            signal, fs_old=fs_old, fs_new=fs_old_round
        )
        signal = signal.astype(np.float32)
        return signal

    def _read_states_raw(self, path_states_file):
        """Loads hypnogram from 'path_states_file'.

        Returns:
            hypnogram: array with one label per page, starting at the first
                scored page; unscored pages and one extra final page hold
                self.unknown_id.
            start_sample: sample (at self.fs) where the first scored page starts.
        """
        with pyedflib.EdfReader(path_states_file) as file:
            annotations = file.readAnnotations()
        onsets = np.array(annotations[0])  # In seconds
        durations = np.round(np.array(annotations[1]))  # In seconds
        # Boolean masking below needs an ndarray, whatever the reader returned
        stages_str = np.asarray(annotations[2])
        # keep only 20s durations
        valid_idx = durations == self.page_duration
        onsets = onsets[valid_idx]
        stages_str = stages_str[valid_idx]
        # The stage is the last character of each annotation
        stages_char = np.asarray([single_annot[-1] for single_annot in stages_str])
        # Sort by onset
        sorted_locs = np.argsort(onsets)
        onsets = onsets[sorted_locs]
        stages_char = stages_char[sorted_locs]
        # The hypnogram could start at a sample different from 0
        start_time = onsets[0]
        onsets_relative = onsets - start_time
        onsets_pages = np.round(onsets_relative / self.page_duration).astype(np.int32)
        # might be greater than onsets_pages.size if some labels are missing
        n_scored_pages = 1 + onsets_pages[-1]
        start_sample = int(start_time * self.fs)
        # if missing, it will be "?", we add one final '?'
        hypnogram = (n_scored_pages + 1) * [self.unknown_id]
        for scored_pos, scored_label in zip(onsets_pages, stages_char):
            hypnogram[scored_pos] = scored_label
        hypnogram = np.asarray(hypnogram)
        return hypnogram, start_sample

    def _fix_signal_and_states(self, signal, hypnogram, start_sample):
        """Crops signal and hypnogram to a common whole number of pages.

        Returns:
            signal: signal[start_sample:] truncated to the valid pages.
            hypnogram: hypnogram truncated to the same number of pages.
            end_sample: last valid sample (exclusive) with respect to the
                original beginning of the recording, useful for marks.
        """
        # Crop start of signal
        signal = signal[start_sample:]
        # Find the largest valid sample, common in both signal and hypnogram,
        # with an integer number of pages
        n_samples_from_signal = int(self.page_size * (signal.size // self.page_size))
        n_samples_from_hypnogram = int(hypnogram.size * self.page_size)
        n_samples_valid = min(n_samples_from_signal, n_samples_from_hypnogram)
        # Exact integer division (both operands are whole numbers of samples)
        n_pages_valid = n_samples_valid // self.page_size
        # Fix signal and hypnogram according to this maximum sample
        signal = signal[:n_samples_valid]
        hypnogram = hypnogram[:n_pages_valid]
        end_sample = start_sample + n_samples_valid
        return signal, hypnogram, end_sample
class ModaSS(Dataset):
    """Dataset of sleep-spindle annotations built from preprocessed MODA segments.

    The data is read from a single npz file (see _get_data_path) holding the
    arrays "subjects", "phases", "signals" and "labels", one row per segment.
    """

    def __init__(self, params=None, load_checkpoint=False, verbose=True, **kwargs):
        """Constructor. Extra keyword arguments are accepted and ignored."""
        self.original_fs = 256  # Hz
        self.original_border_duration = 30  # s
        # Hypnogram parameters
        self.unknown_id = "?"  # Character for unknown state in hypnogram
        self.n2_id = "2"  # Character for N2 identification in hypnogram

        # Sleep spindles characteristics
        self.min_ss_duration = 0.3  # Minimum duration of SS in seconds
        self.max_ss_duration = 3  # Maximum duration of SS in seconds

        all_ids = self._get_ids()

        super(ModaSS, self).__init__(
            dataset_dir=PATH_MODA_RELATIVE,
            load_checkpoint=load_checkpoint,
            dataset_name=constants.MODA_SS_NAME,
            all_ids=all_ids,
            event_name=constants.SPINDLE,
            hypnogram_sleep_labels=["2"],
            hypnogram_page_duration=20,
            n_experts=1,
            params=params,
            verbose=verbose,
        )

        self.global_std = None
        if verbose:
            print("Global STD:", self.global_std)

    def cv_split(self, n_folds, fold_id, seed=0, subject_ids=None):
        """Stratified 5-fold or 10-fold CV splits
        stratified in the sense of preserving distribution of phases and n_blocks.

        Inputs:
            n_folds: either 5 or 10
            fold_id: integer in [0, 1, ..., n_folds - 1] (which fold to retrieve)
            seed: random seed (determines the permutation of subjects before k-fold CV)
            subject_ids: subjects to split. If None, self.all_ids is used.

        Returns:
            Sorted lists (train_ids, val_ids, test_ids). The validation fold is
            the test fold that follows 'fold_id' (cyclically).

        Raises:
            ValueError: if n_folds or fold_id are invalid.
        """
        if n_folds not in [5, 10]:
            raise ValueError("%d folds are not supported, choose 5 or 10" % n_folds)
        if not 0 <= fold_id < n_folds:
            raise ValueError("fold id %s invalid for %d folds" % (fold_id, n_folds))
        # Retrieve data
        if subject_ids is None:
            subject_ids = self.all_ids
        subject_ids = np.asarray(list(subject_ids))
        phases = np.asarray(
            [self.data[subject_id][KEY_PHASE] for subject_id in subject_ids]
        )
        n_blocks = np.asarray(
            [self.data[subject_id][KEY_N_BLOCKS] for subject_id in subject_ids]
        )
        # Groups of subjects, each with the rounding applied to its fold size
        # in (even folds, odd folds). Alternating floor/ceil between groups
        # balances the total fold size.
        group_specs = [
            ((phases == 1) & (n_blocks == 10), np.floor, np.ceil),
            ((phases == 2) & (n_blocks == 10), np.ceil, np.floor),
            ((phases == 1) & (n_blocks < 10), np.ceil, np.floor),
            ((phases == 2) & (n_blocks < 10), np.floor, np.ceil),
        ]
        # Random shuffle (same seed for every group)
        groups = [
            np.random.RandomState(seed=seed).permutation(subject_ids[mask])
            for mask, _, _ in group_specs
        ]
        # Form folds
        test_folds = []
        offsets = [0] * len(groups)
        for i in range(n_folds):
            fold_parts = []
            for g, members in enumerate(groups):
                round_fn = group_specs[g][1] if i % 2 == 0 else group_specs[g][2]
                n_take = int(round_fn(members.size / n_folds))
                fold_parts.append(members[offsets[g] : offsets[g] + n_take])
                offsets[g] += n_take
            test_folds.append(np.concatenate(fold_parts))
        # Select split
        test_ids = test_folds[fold_id]
        val_ids = test_folds[(fold_id + 1) % n_folds]
        held_out = set(np.concatenate([val_ids, test_ids]).tolist())
        train_ids = [s for s in subject_ids.tolist() if s not in held_out]
        # Sort
        train_ids = np.sort(train_ids).tolist()
        val_ids = np.sort(val_ids).tolist()
        test_ids = np.sort(test_ids).tolist()
        return train_ids, val_ids, test_ids

    def _get_ids(self):
        """Returns the sorted list of unique subject ids found in the npz file."""
        fpath = self._get_data_path()
        # np.load keeps the npz file open; release it as soon as it is read
        with np.load(fpath) as dataset:
            subjects_of_segments = dataset["subjects"]
        subject_ids = np.unique(subjects_of_segments)
        subject_ids = subject_ids.tolist()
        subject_ids.sort()
        return subject_ids

    def _get_data_path(self):
        """Returns the absolute path of the npz file with the segments."""
        return os.path.abspath(
            os.path.join(utils.PATH_DATA, PATH_MODA_RELATIVE, PATH_SEGMENTS)
        )

    def _load_from_source(self):
        """Loads the data from files and transforms it appropriately."""
        fpath = self._get_data_path()
        data = {}
        n_subjects = len(self.all_ids)
        start = time.time()
        # Arrays inside an npz are read from disk on every access, so each one
        # is loaded once here instead of once per subject.
        with np.load(fpath) as dataset:
            subjects_of_segments = dataset["subjects"]
            phases_of_segments = dataset["phases"]
            all_signals = dataset["signals"]
            all_labels = dataset["labels"]
        for i, subject_id in enumerate(self.all_ids):
            print("\nLoading ID %s" % subject_id)
            subject_locs = np.sort(np.where(subjects_of_segments == subject_id)[0])
            n_blocks = subject_locs.size
            phase = phases_of_segments[subject_locs[0]]
            signals = all_signals[subject_locs, :]
            labels = all_labels[subject_locs, :]
            signals = self._prepare_signals(signals)  # [n_samples]
            labels = self._prepare_labels(labels)  # [n_spindles, 2]
            n2_pages, hypnogram = self._generate_states(n_blocks)
            total_pages = hypnogram.size
            all_pages = np.arange(1, total_pages - 1, dtype=np.int16)
            print("N2 pages: %d" % n2_pages.shape[0])
            print("Whole-night pages: %d" % all_pages.shape[0])
            print("Hypnogram pages: %d" % hypnogram.shape[0])
            print("Marks SS from E1: %d" % labels.shape[0])
            # Save data
            ind_dict = {
                KEY_EEG: signals,
                KEY_N2_PAGES: n2_pages,
                KEY_ALL_PAGES: all_pages,
                "%s_1" % KEY_MARKS: labels,
                KEY_HYPNOGRAM: hypnogram,
                KEY_PHASE: phase,
                KEY_N_BLOCKS: n_blocks,
            }
            data[subject_id] = ind_dict
            print(
                "Loaded ID %s (%03d/%03d ready). Time elapsed: %1.4f [s]"
                % (subject_id, i + 1, n_subjects, time.time() - start)
            )
        print("%d records have been read." % len(data))
        return data

    def _get_crop_size(self):
        """Returns the samples (at original_fs) to remove at each segment side."""
        # Each segment is of length 30s + 115s + 30s
        # We add 2.5s at each side of the blocks to make them of 120s = 6 * 20s
        # Therefore, each segment contributes with six 20s pages.
        # Additionally, we add 20s of border at each block to allow context.
        # Therefore, each segment contributes with two 20s pages of "?" state.
        # In summary, we need 20s + 120s + 20s = 22.5s + 115s + 22.5s
        target_border_duration = 22.5
        return int(
            self.original_fs * (self.original_border_duration - target_border_duration)
        )

    def _prepare_signals(self, list_of_segments):
        """Crops the segments, concatenates them and resamples to self.fs.

        Returns a single float32 signal.
        """
        crop_size = self._get_crop_size()
        list_of_segments = list_of_segments[:, crop_size:-crop_size]
        signal = np.concatenate(list_of_segments)
        # Now resample to the required frequency
        if self.fs != self.original_fs:
            print(
                "Resampling from %d Hz to required %d Hz" % (self.original_fs, self.fs)
            )
            signal = utils.resample_signal(
                signal, fs_old=self.original_fs, fs_new=self.fs
            )
        else:
            print("Signal already at required %d Hz" % self.fs)
        signal = signal.astype(np.float32)
        return signal

    def _prepare_labels(self, list_of_labels):
        """Transforms the per-sample labels of the segments into sample-stamps.

        Labels are cropped and concatenated like the signals, binarized,
        converted to stamps at self.fs, and corrected with the spindle
        duration criteria.
        """
        crop_size = self._get_crop_size()
        list_of_labels = list_of_labels[:, crop_size:-crop_size]
        labels = np.concatenate(list_of_labels)
        # borders will have a label of zero
        binary_labels = np.clip(labels, a_min=0, a_max=1)
        marks = utils.seq2stamp(binary_labels)
        marks_time = marks / self.original_fs  # sample to seconds
        # Transforms to sample-stamps: second to samples in target fs
        marks = np.round(marks_time * self.fs).astype(np.int32)
        # Combine marks that are too close according to standards
        marks = stamp_correction.combine_close_stamps(
            marks, self.fs, self.min_ss_duration
        )
        # Fix durations that are outside standards
        marks = stamp_correction.filter_duration_stamps(
            marks, self.fs, self.min_ss_duration, self.max_ss_duration
        )
        return marks

    def _generate_states(self, n_blocks):
        """Returns the N2 page indices and the hypnogram for 'n_blocks' blocks."""
        # Each block has ?, 6x N2, and ?
        single_block = [self.unknown_id] + 6 * [self.n2_id] + [self.unknown_id]
        hypnogram = n_blocks * single_block
        hypnogram = np.asarray(hypnogram)
        # Extract N2 pages
        n2_pages = np.where(hypnogram == self.n2_id)[0]
        n2_pages = n2_pages.astype(np.int16)
        return n2_pages, hypnogram
28 | # Generation parameters 29 | self.signal_duration = ( 30 | 3600 + 2 * 20 31 | ) # 1 hour of useful signal + 1 page at borders 32 | self.power_matching_highcut = 8 # [Hz] 33 | self.power_matching_target_value = 0.7286483227138594 34 | self.spectrum_profile_fn = self._get_profile_fn() 35 | 36 | all_ids = np.arange(1, self.n_signals + 1).tolist() 37 | super(Pink, self).__init__( 38 | dataset_dir=PATH_PINK_RELATIVE, 39 | load_checkpoint=load_checkpoint, 40 | dataset_name=constants.PINK_NAME, 41 | all_ids=all_ids, 42 | event_name="none", 43 | hypnogram_sleep_labels=["2"], 44 | hypnogram_page_duration=[20], 45 | params=params, 46 | verbose=verbose, 47 | ) 48 | self.global_std = None 49 | if verbose: 50 | print("Global STD", self.global_std) 51 | 52 | self.filt_signal_dict = {} 53 | self.exists_cache = False 54 | 55 | def _load_from_source(self): 56 | n_pages = self.signal_duration // self.page_duration 57 | data = {} 58 | start = time.time() 59 | for i, subject_id in enumerate(self.all_ids): 60 | print("\nGenerating pink noise ID %s" % subject_id) 61 | signal = self._generate_signal(subject_id) 62 | hypnogram = ( 63 | [self.unknown_id] + (n_pages - 2) * [self.n2_id] + [self.unknown_id] 64 | ) 65 | hypnogram = np.asarray(hypnogram) 66 | n2_pages = np.where(hypnogram == self.n2_id)[0].astype(np.int16) 67 | all_pages = np.arange(1, n_pages - 1, dtype=np.int16) 68 | marks = np.zeros(shape=(0, 2)).astype(np.int32) 69 | print("N2 pages: %d" % n2_pages.shape[0]) 70 | print("Whole-night pages: %d" % all_pages.shape[0]) 71 | print("Marks SS from E1: %d" % marks.shape[0]) 72 | # Save data 73 | ind_dict = { 74 | KEY_EEG: signal, 75 | KEY_N2_PAGES: n2_pages, 76 | KEY_ALL_PAGES: all_pages, 77 | KEY_HYPNOGRAM: hypnogram, 78 | "%s_1" % KEY_MARKS: marks, 79 | } 80 | data[subject_id] = ind_dict 81 | print( 82 | "Loaded ID %d (%02d/%02d ready). Time elapsed: %1.4f [s]" 83 | % (subject_id, i + 1, self.n_signals, time.time() - start) 84 | ) 85 | print("%d records have been read." 
% len(data)) 86 | return data 87 | 88 | def _get_profile_fn(self): 89 | pink_profile = np.load( 90 | os.path.join(utils.PATH_DATA, PATH_PINK_RELATIVE, "pink_profile.npy") 91 | ) 92 | profile_fn = interp1d(pink_profile[0], pink_profile[1]) 93 | return profile_fn 94 | 95 | def _generate_signal(self, seed): 96 | # Base noise 97 | n_samples = int(self.signal_duration * self.fs) 98 | x = np.random.RandomState(seed=seed).normal(size=n_samples) 99 | # Scale the FFT spectrum 100 | y = np.fft.rfft(x) 101 | freq_gen = np.fft.rfftfreq(x.size, d=1.0 / self.fs) 102 | scaling = self.spectrum_profile_fn(freq_gen) 103 | y = y * scaling 104 | # Return to time domain and normalize 105 | x = np.fft.irfft(y) 106 | x = x - x.mean() 107 | x = x / x.std() 108 | # Filter to desired band 109 | x = utils.broad_filter(x, self.fs) 110 | # Scale to target amplitude 111 | f, p = utils.power_spectrum_by_sliding_window(x, self.fs) 112 | power_in_band = p[f <= self.power_matching_highcut].mean() 113 | correction_factor = self.power_matching_target_value / power_in_band 114 | x = x * correction_factor 115 | # Cast to desired type 116 | x = x.astype(np.float32) 117 | return x 118 | 119 | def create_signal_cache(self, highcut=4): 120 | signals = self.get_signals(normalize_clip=False) 121 | for k, sub_id in enumerate(self.all_ids): 122 | filt_signal = utils.filter_iir_lowpass(signals[k], self.fs, highcut=highcut) 123 | filt_signal = filt_signal.astype(np.float32) 124 | self.filt_signal_dict[sub_id] = filt_signal 125 | self.exists_cache = True 126 | 127 | def delete_signal_cache(self): 128 | self.filt_signal_dict = {} 129 | self.exists_cache = False 130 | 131 | def get_subject_filt_signal(self, subject_id): 132 | if self.exists_cache: 133 | signal = self.filt_signal_dict[subject_id] 134 | else: 135 | signal = None 136 | return signal 137 | 138 | def get_subset_filt_signals(self, subject_id_list): 139 | if self.exists_cache: 140 | subset_signals = [ 141 | self.get_subject_filt_signal(sub_id) for sub_id 
def combine_close_stamps(marks, fs, min_separation):
    """Merges consecutive sample-stamps separated by less than min_separation.

    Args:
        marks: array of shape [n_marks, 2] with (start, end) sample-stamps.
        fs: sampling frequency, used to express gaps in seconds.
        min_separation: minimum gap in seconds between two marks. If None,
            the marks are returned untouched.

    Returns:
        Array of merged marks, sorted along the first axis.
    """
    if marks.size == 0 or min_separation is None:
        return marks

    ordered = np.sort(marks, axis=0)
    merged_rows = []
    current = ordered[0].copy()
    for candidate in ordered[1:]:
        separation = (candidate[0] - current[1]) / fs
        if separation < min_separation:
            # Absorb the candidate: keep the latest of both endings
            current[1] = max(candidate[1], current[1])
        else:
            merged_rows.append(current)
            current = candidate.copy()
    merged_rows.append(current)
    return np.stack(merged_rows, axis=0)


def filter_duration_stamps(marks, fs, min_duration, max_duration, repair_long=True):
    """Filters sample-stamps according to their duration in seconds.

    Marks shorter than min_duration are dropped. If repair_long is True,
    marks longer than twice max_duration are dropped and the remaining marks
    longer than max_duration are cropped around their center; otherwise every
    mark longer than max_duration is dropped.

    If min_duration is None, no short marks are removed.
    If max_duration is None, no long marks are removed.
    """
    no_limits = min_duration is None and max_duration is None
    if marks.size == 0 or no_limits:
        return marks

    durations = (marks[:, 1] - marks[:, 0] + 1) / fs

    if min_duration is not None:
        long_enough = durations >= min_duration
        marks = marks[long_enough]
        durations = durations[long_enough]

    if max_duration is None:
        return marks

    if not repair_long:
        # No repairing, simply remove
        return marks[durations <= max_duration]

    # Discard implausibly long annotations
    plausible = durations <= 2 * max_duration
    marks = marks[plausible]
    durations = durations[plausible]
    # Trim the excess evenly from both sides so the central part is kept
    excess = np.clip(durations - max_duration, 0, None)
    half_remove = ((fs * excess + 1) / 2).astype(np.int32)
    trimming = np.stack([half_remove, -half_remove], axis=1)
    return marks + trimming
def transform_predicted_proba_to_adjusted_proba(
    predicted_proba, optimal_threshold, eps=1e-8
):
    """
    Shifts the probabilities in logit space so that
    adjusted_proba > 0.5 is equivalent to predicted_proba > optimal_threshold

    :param predicted_proba: vector of predicted probabilities.
    :param optimal_threshold: optimal threshold for class assignment in predicted probabilities.
    :param eps: clipping margin for numerical stability. Defaults to 1e-8.
    :return: the vector of adjusted probabilities.
    """
    # Degenerate thresholds cannot be expressed as a finite logit bias
    if optimal_threshold == 0:
        # Everything is at or above the threshold: map 0 - 1 to 0.5 - 1
        return 0.5 * predicted_proba + 0.5
    if optimal_threshold == 1:
        # Everything is at or below the threshold: map 0 - 1 to 0 - 0.5
        return 0.5 * predicted_proba

    input_dtype = predicted_proba.dtype
    proba = np.clip(
        predicted_proba.astype(np.float64), a_min=eps, a_max=(1.0 - eps)
    )
    logit = np.log(proba / (1.0 - proba))
    threshold_bias = -np.log(optimal_threshold / (1.0 - optimal_threshold))
    shifted_logit = logit + threshold_bias
    # Back to probabilities, in the dtype of the input
    return (1.0 / (1.0 + np.exp(-shifted_logit))).astype(input_dtype)


def transform_thr_for_adjusted_to_thr_for_predicted(
    thr_for_adjusted, optimal_threshold
):
    """
    Maps a threshold on adjusted probabilities to the threshold on predicted
    probabilities that yields the same decisions, i.e.
    predicted_proba > thr_for_predicted is equivalent to adjusted_proba > thr_for_adjusted

    :param thr_for_adjusted: threshold for class assignment in adjusted probabilities.
    :param optimal_threshold: optimal threshold for class assignment in predicted probabilities.
    :return: the equivalent threshold for class assignment in predicted probabilities
    """
    both_high = thr_for_adjusted * optimal_threshold
    both_low = (1.0 - thr_for_adjusted) * (1.0 - optimal_threshold)
    return both_high / (both_high + both_low)


def get_event_probabilities(marks, probability, downsampling_factor=8, proba_prc=75):
    """
    Summarizes the probability inside each mark with a percentile.

    :param marks: sample-stamps (start, end), both ends inclusive.
    :param probability: vector of probabilities, one value every downsampling_factor samples.
    :param downsampling_factor: repetitions applied to each probability value.
    :param proba_prc: percentile computed within each mark.
    :return: array with one probability per mark.
    """
    upsampled = np.repeat(probability, downsampling_factor)
    marks_proba = []
    for mark in marks:
        inside_mark = upsampled[mark[0] : (mark[1] + 1)]
        marks_proba.append(np.percentile(inside_mark, proba_prc))
    return np.array(marks_proba)
def generate_ensemble_from_stamps(
    dict_of_stamps,
    reference_feeder_dataset,
    downsampling_factor=8,
    skip_setting_threshold=False,
):
    """
    Builds an ensemble by turning every set of stamps into a sequence of
    probabilities and averaging them.

    dict_of_stamps = {
        subject_id_1: list of stamps to ensemble,
        subject_id_2: list of stamps to ensemble,
        etc
    }
    """
    dict_of_proba = {}
    for subject_id in reference_feeder_dataset.get_ids():
        stamps_list = dict_of_stamps[subject_id]
        # Largest sample referenced by any member (1 for members without stamps)
        last_samples = [
            (1 if single_stamp.size == 0 else single_stamp.max())
            for single_stamp in stamps_list
        ]
        subject_max_sample = np.max(last_samples)
        # Sequence length: multiple of the downsampling factor, with 10 extra bins
        n_samples = downsampling_factor * (
            (subject_max_sample // downsampling_factor) + 10
        )
        probabilities = []
        for single_stamp in stamps_list:
            binary_sequence = utils.stamp2seq(single_stamp, 0, n_samples - 1)
            # Fraction of active samples inside each bin
            binned = binary_sequence.reshape(-1, downsampling_factor)
            probabilities.append(binned.mean(axis=1))
        dict_of_proba[subject_id] = probabilities
    return generate_ensemble_from_probabilities(
        dict_of_proba,
        reference_feeder_dataset,
        skip_setting_threshold=skip_setting_threshold,
    )


def generate_ensemble_from_predicted_datasets(
    predicted_dataset_list,
    reference_feeder_dataset,
    use_probabilities=False,
    skip_setting_threshold=False,
):
    """
    Builds an ensemble from a list of predicted datasets, using either their
    adjusted probabilities (use_probabilities=True) or their stamps.
    """
    subject_ids = reference_feeder_dataset.get_ids()
    if use_probabilities:
        dict_of_data = {
            subject_id: [
                pred.get_subject_probabilities(subject_id, return_adjusted=True)
                for pred in predicted_dataset_list
            ]
            for subject_id in subject_ids
        }
        return generate_ensemble_from_probabilities(
            dict_of_data,
            reference_feeder_dataset,
            skip_setting_threshold=skip_setting_threshold,
        )
    dict_of_data = {
        subject_id: [
            pred.get_subject_stamps(subject_id) for pred in predicted_dataset_list
        ]
        for subject_id in subject_ids
    }
    return generate_ensemble_from_stamps(
        dict_of_data,
        reference_feeder_dataset,
        skip_setting_threshold=skip_setting_threshold,
    )
hypnogram_sleep_labels=dataset.hypnogram_sleep_labels, 44 | hypnogram_page_duration=dataset.hypnogram_page_duration, 45 | n_experts=dataset.n_experts, 46 | default_expert=which_expert, 47 | default_page_subset=task_mode, 48 | params=dataset.params.copy(), 49 | verbose=verbose, 50 | ) 51 | self.global_std = dataset.global_std 52 | 53 | def read_subject_data(self, subject_id): 54 | # Original data 55 | ind_dict = self.parent_dataset_class.read_subject_data(self, subject_id) 56 | # Return a subsample if required 57 | if self.n2_subsampling_factor < 1: 58 | n2_pages = ind_dict[KEY_N2_PAGES] 59 | 60 | # Random permutation with a fixed seed so that we can sample N2 pages from anywhere 61 | # print("Shuffling N2 pages before subsampling") 62 | n2_pages = np.random.RandomState(seed=0).permutation(n2_pages) 63 | 64 | n_pages_to_extract = int( 65 | np.ceil(self.n2_subsampling_factor * n2_pages.size) 66 | ) 67 | ind_dict[KEY_N2_PAGES] = n2_pages[:n_pages_to_extract] # contiguous segment 68 | return ind_dict 69 | 70 | def _load_from_source(self): 71 | """Loads the data from source.""" 72 | data = self.parent_dataset.get_sub_dataset(self.all_ids) 73 | self.parent_dataset = None 74 | return data 75 | 76 | def get_data_for_training( 77 | self, 78 | border_size=0, 79 | forced_mark_separation_size=0, 80 | return_page_mask=False, 81 | verbose=False, 82 | ): 83 | output = super().get_data( 84 | augmented_page=True, 85 | border_size=border_size, 86 | forced_mark_separation_size=forced_mark_separation_size, 87 | which_expert=self.which_expert, 88 | pages_subset=self.task_mode, 89 | normalize_clip=True, 90 | normalization_mode=self.task_mode, 91 | return_page_mask=return_page_mask, 92 | verbose=verbose, 93 | ) 94 | return output 95 | 96 | def get_data_for_prediction( 97 | self, 98 | border_size=0, 99 | predict_with_augmented_page=True, 100 | return_page_mask=False, 101 | verbose=False, 102 | ): 103 | output = super().get_data( 104 | augmented_page=predict_with_augmented_page, 105 | 
border_size=border_size, 106 | which_expert=self.which_expert, 107 | pages_subset=constants.WN_RECORD, 108 | normalize_clip=True, 109 | normalization_mode=self.task_mode, 110 | return_page_mask=return_page_mask, 111 | verbose=verbose, 112 | ) 113 | return output 114 | 115 | def get_data_for_stats(self, border_size=0, verbose=False): 116 | subset_signals = [] 117 | for subject_id in self.all_ids: 118 | signal = self.get_subject_data_for_stats( 119 | subject_id=subject_id, border_size=border_size, verbose=verbose 120 | ) 121 | subset_signals.append(signal) 122 | return subset_signals 123 | 124 | def get_subject_data_for_stats(self, subject_id, border_size=0, verbose=False): 125 | checks.check_valid_value(subject_id, "ID", self.all_ids) 126 | ind_dict = self.read_subject_data(subject_id) 127 | 128 | # Unpack data 129 | signal = ind_dict[KEY_EEG] 130 | marks = ind_dict["%s_%d" % (KEY_MARKS, self.which_expert)] 131 | # Transform stamps into sequence 132 | marks = utils.stamp2seq(marks, 0, signal.shape[0] - 1) 133 | 134 | if self.task_mode == constants.WN_RECORD: 135 | if verbose: 136 | print("Stats from pages containing true events.") 137 | # Normalize using stats from pages with true events. 
def kcomplex_stamp_split(
    signal,
    stamps,
    fs,
    highcut=4,
    left_edge_tol=0.05,
    right_edge_tol=0.2,
    signal_is_filtered=False,
):
    """Split stamps that contain more than one negative peak.

    Args:
        signal: 1-D array. It is low-pass filtered with
            ``filter_iir_lowpass`` unless ``signal_is_filtered`` is True.
        stamps: iterable of ``[start, end]`` sample indices, both inclusive.
        fs: sampling frequency, used to convert the tolerances to samples.
        highcut: cut-off passed to ``filter_iir_lowpass``.
        left_edge_tol: seconds after the stamp start in which peaks are
            ignored.
        right_edge_tol: seconds before the stamp end in which peaks are
            ignored.
        signal_is_filtered: if True, ``signal`` is used as is.

    Returns:
        int32 array of shape (n, 2). A stamp with several negative peaks
        separated by a change of sign is cut at the midpoint between
        consecutive peaks; the two resulting stamps share the cut sample.
        Any other stamp is returned unchanged.
    """
    left_edge_tol = fs * left_edge_tol
    right_edge_tol = fs * right_edge_tol

    if signal_is_filtered:
        filt_signal = signal
    else:
        filt_signal = filter_iir_lowpass(signal, fs, highcut=highcut)

    new_stamps = []
    for stamp in stamps:
        stamp_size = stamp[1] - stamp[0] + 1
        # Stamps are inclusive, so the slice has to reach stamp[1] itself;
        # otherwise the last sample is left out of the peak search while
        # stamp_size still counts it.
        filt_in_stamp = filt_signal[stamp[0] : stamp[1] + 1]
        candidate_peaks, _ = find_peaks(-filt_in_stamp)
        # Keep peaks with negative amplitude that are away from both edges
        negative_peaks = [
            peak
            for peak in candidate_peaks
            if filt_in_stamp[peak] < 0
            and left_edge_tol < peak < stamp_size - right_edge_tol
        ]

        if len(negative_peaks) > 1:
            # Change of sign filtering: consecutive peaks belong to the same
            # group unless the signal crosses zero between them.
            group_peaks = [[negative_peaks[0]]]
            for this_peak in negative_peaks[1:]:
                last_peak = group_peaks[-1][-1]
                signal_between_peaks = filt_in_stamp[last_peak:this_peak]
                min_value = signal_between_peaks.min()
                max_value = signal_between_peaks.max()
                if min_value < 0 < max_value:
                    # there is a change of sign, so it is a new group
                    group_peaks.append([this_peak])
                else:
                    # No change of sign, same group
                    group_peaks[-1].append(this_peak)
            # Each group is represented by the mean location of its peaks
            negative_peaks = [int(np.mean(group)) for group in group_peaks]

        if len(negative_peaks) > 1:
            # Split marks at the midpoint between consecutive peaks
            edges_list = [stamp[0]]
            for left_peak, right_peak in zip(negative_peaks[:-1], negative_peaks[1:]):
                split_point_rel = (left_peak + right_peak) / 2
                edges_list.append(int(stamp[0] + split_point_rel))
            edges_list.append(stamp[1])
            for start, end in zip(edges_list[:-1], edges_list[1:]):
                new_stamps.append([start, end])
        else:
            new_stamps.append(stamp)

    if len(new_stamps) > 0:
        new_stamps = np.stack(new_stamps, axis=0).astype(np.int32)
    else:
        new_stamps = np.zeros((0, 2)).astype(np.int32)
    return new_stamps


def get_amplitude_spindle(x, fs, distance_in_seconds=0.04):
    """Return the largest peak-to-peak difference between consecutive peaks.

    Local maxima and minima of ``x`` are searched with ``find_peaks`` using a
    minimum separation of ``fs * distance_in_seconds`` samples. If either
    kind of peak is missing, the search is retried up to two more times with
    half the separation each time.

    Returns:
        The maximum absolute difference between consecutive peaks (maxima
        and minima merged and sorted by location), or the sentinel ``1e6``
        when no maxima or no minima were found after all attempts.
    """
    # find_peaks rejects distance < 1, so the separation is clamped at 1 both
    # initially and after each halving.
    distance = max(1, int(fs * distance_in_seconds))
    peaks_max, _ = find_peaks(x, distance=distance)
    peaks_min, _ = find_peaks(-x, distance=distance)
    for attempt_name in ("Second", "Third"):
        if len(peaks_max) > 0 and len(peaks_min) > 0:
            break
        print("%s attempt to find peaks" % attempt_name)
        distance = max(1, distance // 2)
        peaks_max, _ = find_peaks(x, distance=distance)
        peaks_min, _ = find_peaks(-x, distance=distance)

    if len(peaks_max) == 0 or len(peaks_min) == 0:
        print(
            "SKIPPED: Segment without peaks. Found %d peaks max and %d peaks min"
            % (len(peaks_max), len(peaks_min))
        )
        # Sentinel checked by spindle_amplitude_filtering (amp > 1e5)
        return 1e6

    peaks = np.sort(np.concatenate([peaks_max, peaks_min]))
    peak_values = x[peaks]
    peak_to_peak_diff = np.abs(np.diff(peak_values))
    return np.max(peak_to_peak_diff)


def spindle_amplitude_filtering(
    signal, stamps, fs, max_amplitude, lowcut=9.5, highcut=16.5
):
    """Drop stamps whose band-passed peak-to-peak amplitude is too large.

    Args:
        signal: 1-D array, band-pass filtered here with ``apply_bandpass``.
        stamps: array of ``[start, end]`` inclusive sample indices; it is
            indexed with an integer array, so it must support that.
        fs: sampling frequency.
        max_amplitude: stamps with amplitude above this value are removed.
        lowcut: lower band edge passed to ``apply_bandpass``.
        highcut: upper band edge passed to ``apply_bandpass``.

    Returns:
        Tuple ``(valid_stamps, no_peaks_found)`` where ``no_peaks_found`` is
        True if ``get_amplitude_spindle`` returned its no-peaks sentinel for
        at least one stamp.
    """
    filt_signal = apply_bandpass(signal, fs, lowcut=lowcut, highcut=highcut)

    amplitudes = []
    for stamp in stamps:
        amp = get_amplitude_spindle(filt_signal[stamp[0] : stamp[1] + 1], fs)
        if amp > 1e5:
            print("Anomaly mark is", stamp)
        amplitudes.append(amp)
    amplitudes = np.array(amplitudes)

    no_peaks_found = bool(np.any(amplitudes > 1e5))
    valid_locs = np.where(amplitudes <= max_amplitude)[0]
    return stamps[valid_locs], no_peaks_found
class PostProcessor(object):
    """Converts probability sequences into event stamps for one event type.

    ``params`` starts as a copy of ``pkeys.default_params`` and is updated
    with the values given to the constructor.
    """

    def __init__(self, event_name, params=None):
        """Validate ``event_name`` (spindle or K-complex) and store params."""
        checks.check_valid_value(
            event_name, "event_name", [constants.SPINDLE, constants.KCOMPLEX]
        )

        self.event_name = event_name
        self.params = pkeys.default_params.copy()
        if params is not None:
            self.params.update(params)

    def proba2stamps(
        self, proba_data, pages_indices=None, pages_indices_subset=None, thr=0.5
    ):
        """Return the stamps detected in ``proba_data``.

        If ``thr`` is None, ``proba_data`` is assumed to be already binary.
        Otherwise two binarizations are made: one at ``thr`` (detection) and
        one at ``0.85 * thr`` (duration). Candidates come from the low
        threshold and are kept only if they overlap a high-threshold stamp.

        The input is taken to be sampled at ``FS // TOTAL_DOWNSAMPLING_FACTOR``
        and the output is expressed at ``FS``. Close stamps are combined and
        durations filtered before upsampling. When ``pages_indices_subset``
        is given, the stamps are finally restricted to those pages with
        ``extract_pages_for_stamps``.
        """
        # Thresholding
        if thr is None:
            # We assume that sequence is already binary
            binary_high = proba_data
            binary_low = proba_data
        else:
            low_thr_factor = 0.85
            low_thr = thr * low_thr_factor
            binary_high = (proba_data >= thr).astype(np.int32)
            binary_low = (proba_data >= low_thr).astype(np.int32)

        # Transformation to stamps based on low thr (for duration)
        if pages_indices is None:
            stamps_low = seq2stamp(binary_low)
            stamps_high = seq2stamp(binary_high)
        else:
            stamps_low = seq2stamp_with_pages(binary_low, pages_indices)
            stamps_high = seq2stamp_with_pages(binary_high, pages_indices)

        # Only keep candidates that surpassed high threshold (for detection)
        # i.e., only stamps_low intersecting with stamps_high
        overlap_check = get_overlap_matrix(
            stamps_low, stamps_high
        )  # shape (n_low, n_high)
        if overlap_check.sum() == 0:
            stamps = np.zeros((0, 2), dtype=np.int32)
        else:
            hits_per_low = overlap_check.sum(axis=1)  # shape (n_low,)
            stamps = stamps_low[np.where(hits_per_low > 0)[0]]

        # Postprocessing
        # Note that when min_separation, min_duration, or max_duration is
        # None, that postprocessing doesn't happen.
        downsampling_factor = self.params[pkeys.TOTAL_DOWNSAMPLING_FACTOR]
        fs_input = self.params[pkeys.FS] // downsampling_factor
        fs_output = self.params[pkeys.FS]

        if self.event_name == constants.SPINDLE:
            min_separation = self.params[pkeys.SS_MIN_SEPARATION]
            min_duration = self.params[pkeys.SS_MIN_DURATION]
            max_duration = self.params[pkeys.SS_MAX_DURATION]
        else:
            min_separation = self.params[pkeys.KC_MIN_SEPARATION]
            min_duration = self.params[pkeys.KC_MIN_DURATION]
            max_duration = self.params[pkeys.KC_MAX_DURATION]

        # Long detections are repaired by default when the key is absent
        repair_long = self.params.get(pkeys.REPAIR_LONG_DETECTIONS, True)

        stamps = combine_close_stamps(stamps, fs_input, min_separation)
        stamps = filter_duration_stamps(
            stamps, fs_input, min_duration, max_duration, repair_long=repair_long
        )

        # Upsampling
        if fs_output < fs_input:
            raise ValueError("fs_output has to be greater than fs_input")
        if fs_output > fs_input:
            stamps = self._upsample_stamps(stamps)

        if pages_indices_subset is not None:
            page_size = int(self.params[pkeys.PAGE_DURATION] * fs_output)
            stamps = extract_pages_for_stamps(stamps, pages_indices_subset, page_size)

        return stamps

    def proba2stamps_with_list(
        self,
        pages_sequence_list,
        pages_indices_list=None,
        pages_indices_subset_list=None,
        thr=0.5,
    ):
        """Apply ``proba2stamps`` to every sequence, in parallel with joblib.

        Missing ``pages_indices_list`` or ``pages_indices_subset_list`` are
        replaced by lists of None of the same length as the sequences.
        """
        n_sequences = len(pages_sequence_list)
        if pages_indices_list is None:
            pages_indices_list = [None] * n_sequences
        if pages_indices_subset_list is None:
            pages_indices_subset_list = [None] * n_sequences

        jobs = zip(pages_sequence_list, pages_indices_list, pages_indices_subset_list)
        stamps_list = Parallel(n_jobs=-1)(
            delayed(self.proba2stamps)(
                pages_sequence,
                pages_indices,
                pages_indices_subset=pages_indices_subset,
                thr=thr,
            )
            for (pages_sequence, pages_indices, pages_indices_subset) in jobs
        )

        return stamps_list

    def _upsample_stamps(self, stamps):
        """Upsamples timestamps of stamps to match a greater sampling frequency.

        With aligned downsampling the end is moved to the last sample covered
        by the downsampled index; otherwise both edges are widened by half of
        the factor.
        """
        upsample_factor = self.params[pkeys.TOTAL_DOWNSAMPLING_FACTOR]
        aligned_down = self.params.get(pkeys.ALIGNED_DOWNSAMPLING, False)

        stamps = stamps * upsample_factor
        if aligned_down:
            stamps[:, 1] = stamps[:, 1] + upsample_factor - 1
        else:
            # NOTE(review): a stamp starting at downsampled index 0 gets a
            # negative start here — confirm callers tolerate that.
            stamps[:, 0] = stamps[:, 0] - upsample_factor // 2
            stamps[:, 1] = stamps[:, 1] + upsample_factor // 2
        return stamps.astype(np.int32)
class PredictedDataset(Dataset):
    """Dataset whose marks are model detections derived from probabilities.

    The per-subject probabilities are kept in ``probabilities_dict`` and are
    converted to stamps by a ``PostProcessor`` every time the probability
    threshold is set. The resulting stamps are stored as the marks of
    expert 1.
    """

    def __init__(
        self,
        dataset: FeederDataset,
        probabilities_dict,
        params=None,
        verbose=False,
        skip_setting_threshold=False,
    ):
        """Constructor.

        Args:
            dataset: dataset that was fed to the model; its configuration is
                copied and it is kept as ``parent_dataset``.
            probabilities_dict: maps subject ID to its probability vector.
                Its keys must match ``dataset.all_ids``.
            params: optional overrides for the postprocessor parameters. A
                copy is used, so the caller's dict is not modified.
            verbose: forwarded to the base constructor.
            skip_setting_threshold: if True, no threshold is set here and
                therefore no stamps are computed yet.

        Raises:
            ValueError: if the IDs of the probabilities and of the dataset
                differ.
        """
        # make the changes local
        params = {} if (params is None) else params.copy()
        # Force for the INTA case
        if "inta" in dataset.dataset_name:
            print("inta contained in dataset name '%s'" % dataset.dataset_name)
            print(
                "Overwriting SS postprocessing parameters for INTA dataset to sep 0.5, min 0.5, max 5.0"
            )
            params[pkeys.SS_MIN_SEPARATION] = 0.5
            params[pkeys.SS_MIN_DURATION] = 0.5
            params[pkeys.SS_MAX_DURATION] = 5.0

        self.parent_dataset = dataset
        self.task_mode = dataset.task_mode
        self.probabilities_dict = probabilities_dict
        self.postprocessor = PostProcessor(event_name=dataset.event_name, params=params)
        self.probability_threshold = None

        super(PredictedDataset, self).__init__(
            dataset_dir=dataset.dataset_dir,
            load_checkpoint=False,
            dataset_name="%s_predicted" % dataset.dataset_name,
            all_ids=dataset.all_ids,
            event_name=dataset.event_name,
            hypnogram_sleep_labels=dataset.hypnogram_sleep_labels,
            hypnogram_page_duration=dataset.hypnogram_page_duration,
            default_expert=1,
            default_page_subset=dataset.task_mode,
            n_experts=1,
            params=dataset.params.copy(),
            verbose=verbose,
        )
        self.global_std = dataset.global_std
        # Check that subject ids in probabilities are the same as the ones
        # on the dataset
        ids_proba = sorted(self.probabilities_dict.keys())
        ids_data = sorted(dataset.all_ids)
        if ids_data != ids_proba:
            raise ValueError(
                "IDs mismatch: IDs from predictions are %s "
                "but IDs from given dataset are %s" % (ids_proba, ids_data)
            )
        if not skip_setting_threshold:
            self.set_probability_threshold(0.5)

    def _load_from_source(self):
        """Loads the data from source.

        Only the page indices are copied from the parent; signal and marks
        are left as None. The parent dataset reference is released afterwards.
        """
        # Extract only necessary stuff
        data = {}
        for sub_id in self.all_ids:
            ind_dict = self.parent_dataset.read_subject_data(sub_id)
            data[sub_id] = {
                KEY_EEG: None,
                KEY_N2_PAGES: ind_dict[KEY_N2_PAGES],
                KEY_ALL_PAGES: ind_dict[KEY_ALL_PAGES],
                "%s_%d" % (KEY_MARKS, 1): None,
            }
        self.parent_dataset = None
        return data

    def set_probability_threshold(
        self, new_probability_threshold, adjusted_by_threshold=None, verbose=False
    ):
        """Sets a new probability threshold and updates the stamps accordingly.

        If adjusted_by_threshold (float between 0 and 1) is set, then the given
        new probability threshold is treated as a threshold for probabilities
        ADJUSTED by the given value, i.e., probabilities that satisfy:

        adjusted_proba > 0.5 <=> predicted_proba > adjusted_by_threshold

        Therefore, the value of new_probability_threshold is first transformed
        to its equivalent in the predicted probabilities domain, and then it is
        applied to them.

        If adjusted_by_threshold is None, then the given new probability
        threshold is used directly on the predicted probabilities.
        """
        if adjusted_by_threshold is not None:
            new_probability_threshold = (
                det_utils.transform_thr_for_adjusted_to_thr_for_predicted(
                    new_probability_threshold, adjusted_by_threshold
                )
            )
        if verbose:
            print("New threshold: %1.8f" % new_probability_threshold)
        self.probability_threshold = new_probability_threshold
        self._update_stamps()

    def _update_stamps(self):
        """Recompute the stamps of every subject with the current threshold.

        Raises:
            RuntimeError: if the dataset name contains "nsrr" and no parent
                dataset is attached, since the NSRR amplitude filter needs to
                read signals from it.
        """
        probabilities_list = [
            self.get_subject_probabilities(sub_id, return_adjusted=True)
            for sub_id in self.all_ids
        ]

        if self.task_mode == constants.N2_RECORD:
            # Keep only N2 stamps
            n2_pages_val = self.get_pages(pages_subset=constants.N2_RECORD)
        else:
            n2_pages_val = None

        stamps_list = self.postprocessor.proba2stamps_with_list(
            probabilities_list, pages_indices_subset_list=n2_pages_val, thr=0.5
        )  # thr is 0.5 because probas are adjusted

        # KC postprocessing
        if self.event_name == constants.KCOMPLEX:
            signals = self.get_signals_external()
            stamps_list = [
                postprocessing.kcomplex_stamp_split(
                    signal, stamps, self.fs, signal_is_filtered=True
                )
                for signal, stamps in zip(signals, stamps_list)
            ]

        # NSRR Amplitude removal
        # parent_dataset is None after _load_from_source and after
        # delete_parent_dataset, so its name cannot be read unconditionally.
        # This dataset's own name is "<parent name>_predicted" and therefore
        # carries the same "nsrr" marker.
        parent_dataset = self.parent_dataset
        if parent_dataset is None:
            source_name = self.dataset_name
        else:
            source_name = parent_dataset.dataset_name
        if "nsrr" in source_name:
            if parent_dataset is None:
                raise RuntimeError(
                    "NSRR amplitude filtering requires a parent dataset; "
                    "call set_parent_dataset() before updating the stamps"
                )
            max_amplitude = 134.12087769782073  # uV, from MODA spindles
            new_stamps_list = []
            for k, sub_id in enumerate(self.all_ids):
                # Load signal
                sub_data = parent_dataset.read_subject_data(
                    sub_id, exclusion_of_pages=False
                )
                signal = sub_data["signal"]
                stamps = stamps_list[k]
                if stamps.size > 0:
                    stamps, no_peaks_found = postprocessing.spindle_amplitude_filtering(
                        signal, stamps, self.fs, max_amplitude
                    )
                    if no_peaks_found:
                        print("Found error 'no peaks found' in subject %s" % sub_id)
                new_stamps_list.append(stamps)
            stamps_list = new_stamps_list

        # Now save model stamps
        stamp_key = "%s_%d" % (KEY_MARKS, 1)
        for k, sub_id in enumerate(self.all_ids):
            self.data[sub_id][stamp_key] = stamps_list[k]

    def get_subject_stamps_probabilities(
        self,
        subject_id,
        pages_subset=None,
        return_adjusted=True,
        proba_prc=75,
    ):
        """Return one probability per stamp of the subject.

        The values are computed by ``det_utils.get_event_probabilities`` from
        the subject's stamps and probability vector, using ``proba_prc`` and
        the total downsampling factor in ``params``.
        """
        subject_stamps = self.get_subject_stamps(subject_id, pages_subset=pages_subset)
        subject_proba = self.get_subject_probabilities(
            subject_id, return_adjusted=return_adjusted
        )
        subject_stamp_proba = det_utils.get_event_probabilities(
            subject_stamps,
            subject_proba,
            downsampling_factor=self.params[pkeys.TOTAL_DOWNSAMPLING_FACTOR],
            proba_prc=proba_prc,
        )
        return subject_stamp_proba

    def get_subset_stamps_probabilities(
        self,
        subject_ids,
        pages_subset=None,
        return_adjusted=True,
        proba_prc=75,
    ):
        """Return ``get_subject_stamps_probabilities`` for each given ID."""
        return [
            self.get_subject_stamps_probabilities(
                sub_id,
                pages_subset=pages_subset,
                return_adjusted=return_adjusted,
                proba_prc=proba_prc,
            )
            for sub_id in subject_ids
        ]

    def get_stamps_probabilities(
        self,
        pages_subset=None,
        return_adjusted=True,
        proba_prc=75,
    ):
        """Return ``get_subject_stamps_probabilities`` for every subject."""
        return self.get_subset_stamps_probabilities(
            self.all_ids,
            pages_subset=pages_subset,
            return_adjusted=return_adjusted,
            proba_prc=proba_prc,
        )

    def get_subject_probabilities(self, subject_id, return_adjusted=False):
        """Returns the subject's predicted probability vector.

        If return_adjusted is False (default), the predicted probabilities are
        returned. If return_adjusted is True, the predicted probabilities are
        first ADJUSTED by the set probability threshold, i.e., they are
        transformed to probabilities that satisfy:

        adjusted_proba > 0.5 <=> predicted_proba > probability_threshold

        after the adjustment, the adjusted probabilities are returned.
        A copy is always returned, so the stored vector is not modified.
        """
        subject_probabilities = self.probabilities_dict[subject_id].copy()
        if return_adjusted:
            subject_probabilities = (
                det_utils.transform_predicted_proba_to_adjusted_proba(
                    subject_probabilities, self.probability_threshold
                )
            )
        return subject_probabilities

    def get_subset_probabilities(self, subject_ids, return_adjusted=False):
        """Return ``get_subject_probabilities`` for each given ID."""
        return [
            self.get_subject_probabilities(sub_id, return_adjusted=return_adjusted)
            for sub_id in subject_ids
        ]

    def get_probabilities(self, return_adjusted=False):
        """Return ``get_subject_probabilities`` for every subject."""
        return self.get_subset_probabilities(
            self.all_ids, return_adjusted=return_adjusted
        )

    def set_parent_dataset(self, dataset):
        """Attach ``dataset`` as the parent dataset."""
        self.parent_dataset = dataset

    def delete_parent_dataset(self):
        """Release the reference to the parent dataset."""
        self.parent_dataset = None

    def get_signals_external(self):
        """Return the filtered signals of all subjects from the parent dataset.

        If no parent is attached, one is loaded with ``load_dataset`` using
        the first two "_"-separated tokens of this dataset's name. The
        filtered-signal cache is created on the parent if it does not exist.
        """
        if self.parent_dataset is None:
            tmp_name = self.dataset_name
            parent_dataset_name = "_".join(tmp_name.split("_")[:2])
            parent_dataset = load_dataset(
                parent_dataset_name,
                params=self.params,
                load_checkpoint=True,
                verbose=False,
            )
        else:
            parent_dataset = self.parent_dataset
        if not parent_dataset.exists_filt_signal_cache():
            print("Creating cache that does not exist")
            parent_dataset.create_signal_cache()
        signals = parent_dataset.get_subset_filt_signals(self.all_ids)
        return signals
def find_envelope(x, win_size):
    """Return the moving maximum of ``|x|`` over a centered, circular window.

    The window spans ``win_size // 2`` samples to each side. Shifts are made
    with ``np.roll``, so samples wrap around at both ends of ``x``.
    """
    half_win_size = int(win_size // 2)
    abs_x = np.abs(x)
    # Running maximum over the shifts: same result as stacking every shifted
    # copy and reducing, without holding all the copies in memory at once.
    envelope = abs_x.copy()
    for shift in range(-half_win_size, half_win_size + 1):
        if shift != 0:
            np.maximum(envelope, np.roll(abs_x, shift), out=envelope)
    return envelope


def get_sigma_envelope(x, fs, lowcut=11, highcut=16, win_duration=0.1):
    """Band-pass ``x`` with ``utils.apply_bandpass`` and return its envelope.

    The envelope window is ``int(fs * win_duration)`` samples.
    """
    signal_sigma = utils.apply_bandpass(x, fs, lowcut, highcut)
    win_size = int(fs * win_duration)
    signal_sigma_env = find_envelope(signal_sigma, win_size)
    return signal_sigma_env


def simple_detector_from_envelope(
    signal_sigma_env,
    fs,
    amplitude_high_thr,
    amplitude_low_thr_factor=4 / 9,
    min_separation=0.3,
    min_duration_low=0.5,
    min_duration_high=0.3,
    max_duration=3.0,
):
    """Detect events from an envelope using a high and a low threshold.

    Candidates are the segments where the envelope is at least
    ``amplitude_high_thr * amplitude_low_thr_factor``. Close candidates are
    combined and short ones removed. A candidate is kept only if the envelope
    is at least ``amplitude_high_thr`` during ``min_duration_high`` seconds
    or more (not necessarily contiguous). Finally the durations are filtered
    again, this time also with ``max_duration``.

    Returns:
        Array of ``[start, end]`` stamps; shape (0, 2) int32 if no candidate
        satisfies the high-amplitude criterion.
    """
    amplitude_low_thr = amplitude_high_thr * amplitude_low_thr_factor

    feat_high = (signal_sigma_env >= amplitude_high_thr).astype(np.int32)
    feat_low = (signal_sigma_env >= amplitude_low_thr).astype(np.int32)
    # 0: below low thr, 1: between thresholds, 2: at or above high thr
    feat = feat_high + feat_low
    events = utils.seq2stamp(feat_low)  # Candidates

    # Group candidate events closer than min_separation (to remove small
    # fluctuations) and remove events shorter than min_duration_low
    events = stamp_correction.combine_close_stamps(
        events, fs, min_separation=min_separation
    )
    events = stamp_correction.filter_duration_stamps(
        events, fs, min_duration=min_duration_low, max_duration=None
    )

    # Criteria of higher amplitude
    min_duration_high = min_duration_high * fs
    new_events = [
        e for e in events if np.sum(feat[e[0] : e[1] + 1] == 2) >= min_duration_high
    ]
    if len(new_events) == 0:
        events = np.zeros((0, 2)).astype(np.int32)
    else:
        events = np.stack(new_events, axis=0)

    # Now remove events that are too long
    events = stamp_correction.filter_duration_stamps(
        events, fs, min_duration=min_duration_low, max_duration=max_duration
    )

    return events


def simple_detector_absolute(
    x,
    fs,
    amplitude_high_thr,
    amplitude_low_thr_factor=4 / 9,
    min_separation=0.3,
    min_duration_low=0.5,
    min_duration_high=0.3,
    max_duration=3.0,
    lowcut=11,
    highcut=16,
    win_duration=0.1,
):
    """Detection using absolute amplitudes"""
    signal_sigma_env = get_sigma_envelope(
        x, fs, lowcut=lowcut, highcut=highcut, win_duration=win_duration
    )

    # detect
    events = simple_detector_from_envelope(
        signal_sigma_env,
        fs,
        amplitude_high_thr,
        amplitude_low_thr_factor=amplitude_low_thr_factor,
        min_separation=min_separation,
        min_duration_low=min_duration_low,
        min_duration_high=min_duration_high,
        max_duration=max_duration,
    )
    return events


def simple_detector_relative(
    x,
    fs,
    amplitude_high_factor,
    pages_to_compute_baseline,
    page_duration,
    amplitude_low_thr_factor=4 / 9,
    min_separation=0.3,
    min_duration_low=0.5,
    min_duration_high=0.3,
    max_duration=3.0,
    lowcut=11,
    highcut=16,
    win_duration=0.1,
):
    """Detection using amplitudes relative to baseline sigma activity.

    The baseline is the median of the sigma envelope inside the pages
    ``pages_to_compute_baseline``; the high threshold is
    ``amplitude_high_factor`` times that baseline.
    """
    signal_sigma_env = get_sigma_envelope(
        x, fs, lowcut=lowcut, highcut=highcut, win_duration=win_duration
    )

    # compute baseline
    page_size = int(page_duration * fs)
    # Only whole pages can be reshaped; trailing samples that do not fill a
    # page are ignored for the baseline (they are still used for detection).
    n_full_pages = signal_sigma_env.size // page_size
    env_by_page = signal_sigma_env[: n_full_pages * page_size].reshape(-1, page_size)
    baseline_sigma_env = np.median(env_by_page[pages_to_compute_baseline])
    # compute relative high thr
    amplitude_high_thr = amplitude_high_factor * baseline_sigma_env

    # detect
    events = simple_detector_from_envelope(
        signal_sigma_env,
        fs,
        amplitude_high_thr,
        amplitude_low_thr_factor=amplitude_low_thr_factor,
        min_separation=min_separation,
        min_duration_low=min_duration_low,
        min_duration_high=min_duration_high,
        max_duration=max_duration,
    )
    return events
def _build_threshold_list(min_thr, max_thr, step):
    """Return thresholds from min_thr in increments of step, rounded to 2 decimals."""
    n_thr = int(np.floor((max_thr - min_thr) / step + 1))
    thr_list = np.array([min_thr + step * i for i in range(n_thr)])
    return np.round(thr_list, 2)


def _gather_expert_stamps(feeder_dataset_list):
    """Concatenate the expert stamps of every feeder dataset, in order."""
    events_list = []
    for feeder_dataset in feeder_dataset_list:
        events_list = events_list + feeder_dataset.get_stamps()
    return events_list


def _gather_detections_per_threshold(predicted_dataset_list, thr_list):
    """Return, for each threshold, the concatenated stamps of all datasets.

    Every predicted dataset has its probability threshold set to each value
    of ``thr_list`` in turn, so after the call the datasets are left at the
    last threshold of the list.
    """
    predictions_at_thr_list = []
    for thr in thr_list:
        detections_list = []
        for predicted_dataset in predicted_dataset_list:
            predicted_dataset.set_probability_threshold(thr)
            detections_list = detections_list + predicted_dataset.get_stamps()
        predictions_at_thr_list.append(detections_list)
    return predictions_at_thr_list


def fit_threshold(
    feeder_dataset_list,
    predicted_dataset_list,
    threshold_space,
    average_mode,
    return_best_af1=False,
):
    """Return the threshold of the grid that maximizes the averaged metric.

    Args:
        feeder_dataset_list: datasets providing the expert stamps.
        predicted_dataset_list: datasets providing the detections; their
            probability threshold is modified during the search.
        threshold_space: dict with keys "min", "max" and "step".
        average_mode: ``constants.MACRO_AVERAGE`` or
            ``constants.MICRO_AVERAGE``, selecting the metric function.
        return_best_af1: if True, the best metric value is returned as well.

    Returns:
        ``best_thr``, or ``(best_thr, best_metric)`` if ``return_best_af1``.
    """
    thr_list = _build_threshold_list(
        threshold_space["min"], threshold_space["max"], threshold_space["step"]
    )
    events_list = _gather_expert_stamps(feeder_dataset_list)
    predictions_at_thr_list = _gather_detections_per_threshold(
        predicted_dataset_list, thr_list
    )

    metric_fn_dict = {
        constants.MACRO_AVERAGE: metrics.average_metric_macro_average,
        constants.MICRO_AVERAGE: metrics.average_metric_micro_average,
    }
    metric_fn = metric_fn_dict[average_mode]
    af1_list = Parallel(n_jobs=-1)(
        delayed(metric_fn)(events_list, single_prediction_list)
        for single_prediction_list in predictions_at_thr_list
    )
    max_idx = np.argmax(af1_list).item()
    best_thr = thr_list[max_idx]
    if return_best_af1:
        return best_thr, af1_list[max_idx]
    return best_thr


def get_optimal_threshold(
    feeder_dataset_list,
    predicted_dataset_list,
    res_thr=0.02,
    start_thr=0.2,
    end_thr=0.8,
    verbose=False,
):
    """Return the threshold in [start_thr, end_thr] with the best metric.

    The grid goes from ``start_thr`` in steps of ``res_thr``. Each threshold
    is scored with ``metrics.average_metric_with_list`` on the stamps of all
    datasets, and the first threshold reaching the maximum is returned. The
    probability threshold of the predicted datasets is modified during the
    search.
    """
    if verbose:
        print("Start thr: %1.4f. End thr: %1.4f" % (start_thr, end_thr))
    thr_list = _build_threshold_list(start_thr, end_thr, res_thr)
    if verbose:
        print(
            "%d thresholds to be evaluated between %1.4f and %1.4f"
            % (thr_list.size, thr_list[0], thr_list[-1])
        )

    events_list = _gather_expert_stamps(feeder_dataset_list)
    predictions_at_thr_list = _gather_detections_per_threshold(
        predicted_dataset_list, thr_list
    )

    af1_list = Parallel(n_jobs=-1)(
        delayed(metrics.average_metric_with_list)(
            events_list, single_prediction_list, verbose=False
        )
        for single_prediction_list in predictions_at_thr_list
    )

    max_idx = np.argmax(af1_list).item()
    best_thr = thr_list[max_idx]
    return best_thr
def get_inta_eeg_names(signal_names):
    """Return the names that look like a two-electrode derivation.

    A name is kept when it is exactly 5 characters long and its third
    character is "-" (for example "F3-C3"). Input order is preserved.
    """
    return [
        single_name
        for single_name in signal_names
        if len(single_name) == 5 and single_name[2] == "-"
    ]


def _first_name_in(signal_names, possible_names):
    """Return the first entry of signal_names found in possible_names, or None."""
    return next(
        (single_name for single_name in signal_names if single_name in possible_names),
        None,
    )


def get_inta_eog_emg_names(signal_names):
    """Return the first EOG name and the first EMG name found, in that order.

    Each signal type contributes at most one name: the first element of
    `signal_names` that matches one of its accepted spellings. A type with
    no match is simply omitted, so the result has 0, 1 or 2 entries.
    """
    candidates_by_type = (
        ["MOR", "ojo izquierdo"],  # Look for EOG
        ["EMG", "EMG menton"],  # Look for EMG
    )
    result = []
    for possible_names in candidates_by_type:
        found = _first_name_in(signal_names, possible_names)
        if found is not None:
            result.append(found)
    return result


def custom_linspace(start_value, end_value, step_value):
    """Return start_value, start_value + step_value, ... up to about end_value.

    The number of steps is (end_value - start_value) / step_value rounded to
    the nearest integer, so small floating-point errors in that quotient do
    not drop the last point.
    """
    n_points = int(np.round((end_value - start_value) / step_value))
    array = start_value + np.arange(n_points + 1) * step_value
    return array


def closest_index(single_value, array):
    """Return the index of the element of `array` closest to `single_value`.

    `array` may be any array-like (ndarray, list, tuple); it is converted
    with np.asarray so that plain Python sequences no longer raise a
    TypeError on the subtraction. Ties resolve to the first index, as
    np.argmin does.
    """
    return np.argmin((single_value - np.asarray(array)) ** 2)
def performance_vs_iou_with_seeds(
    dataset,
    predictions_dict,
    optimal_thr_list,
    iou_curve_axis,
    iou_hist_bins,
    task_mode,
    which_expert,
    set_name=constants.TEST_SUBSET,
):
    """Measure by-event performance for every seed and stack the results.

    Seeds are the keys of `predictions_dict`, visited in sorted order. For
    each seed, the expert stamps of the `set_name` subset are compared with
    the model stamps obtained at `optimal_thr_list[seed]` through
    `performance_vs_iou`.

    Returns:
        dict keyed by `constants` names. Every per-seed measurement is
        stacked along a new first axis (one entry per seed);
        `iou_hist_bins` and `iou_curve_axis` are passed through unchanged.
    """
    seed_id_list = sorted(predictions_dict.keys())
    ids_dict = misc.get_splits_dict(dataset, seed_id_list)

    # One accumulator list per measurement returned by performance_vs_iou.
    per_seed_keys = (
        constants.F1_VS_IOU,
        constants.RECALL_VS_IOU,
        constants.PRECISION_VS_IOU,
        constants.MEAN_AF1,
        constants.MEAN_IOU,
        constants.IQR_LOW_IOU,
        constants.IQR_HIGH_IOU,
        constants.IOU_HIST_VALUES,
    )
    per_seed = {key: [] for key in per_seed_keys}

    for k in seed_id_list:
        # Expert
        subset_data = FeederDataset(
            dataset, ids_dict[k][set_name], task_mode, which_expert=which_expert
        )
        events = subset_data.get_stamps()
        # Model
        prediction_data = predictions_dict[k][set_name]
        prediction_data.set_probability_threshold(optimal_thr_list[k])
        detections = prediction_data.get_stamps()
        # Measure stuff
        results = performance_vs_iou(events, detections, iou_curve_axis, iou_hist_bins)
        for key in per_seed_keys:
            per_seed[key].append(results[key])

    stacked = {key: np.stack(values, axis=0) for key, values in per_seed.items()}
    model_data_dict = {
        constants.F1_VS_IOU: stacked[constants.F1_VS_IOU],
        constants.RECALL_VS_IOU: stacked[constants.RECALL_VS_IOU],
        constants.PRECISION_VS_IOU: stacked[constants.PRECISION_VS_IOU],
        constants.IOU_HIST_BINS: iou_hist_bins,
        constants.IOU_CURVE_AXIS: iou_curve_axis,
        constants.IOU_HIST_VALUES: stacked[constants.IOU_HIST_VALUES],
        constants.MEAN_IOU: stacked[constants.MEAN_IOU],
        constants.MEAN_AF1: stacked[constants.MEAN_AF1],
        constants.IQR_LOW_IOU: stacked[constants.IQR_LOW_IOU],
        constants.IQR_HIGH_IOU: stacked[constants.IQR_HIGH_IOU],
    }
    return model_data_dict
def duration_scatter_with_seeds(
    dataset,
    predictions_dict,
    optimal_thr_list,
    task_mode,
    which_expert,
    set_name=constants.TEST_SUBSET,
):
    """Pair the durations of matched expert events and detections per seed.

    Seeds are the keys of `predictions_dict`, visited in sorted order. For
    each seed, expert stamps and model stamps (at `optimal_thr_list[seed]`)
    are matched with `metrics.matching_with_list`; only pairs whose matching
    index is greater than -1 are kept.

    Returns:
        dict mapping seed id to a dict with the keys "expert_idx",
        "detection_idx", "expert_duration" and "detection_duration". Each
        value is a list with one array per entry of the expert stamps list.
        Durations are computed as stamp end minus stamp start.
    """
    seed_id_list = sorted(predictions_dict.keys())
    ids_dict = misc.get_splits_dict(dataset, seed_id_list)

    results = {}
    for k in seed_id_list:
        # Expert
        subset_data = FeederDataset(
            dataset, ids_dict[k][set_name], task_mode, which_expert=which_expert
        )
        events = subset_data.get_stamps()
        # Model
        prediction_data = predictions_dict[k][set_name]
        prediction_data.set_probability_threshold(optimal_thr_list[k])
        detections = prediction_data.get_stamps()
        # Matching (the IoU values themselves are not needed here)
        _, idx_matchings = metrics.matching_with_list(events, detections)

        seed_results = {
            "expert_idx": [],
            "detection_idx": [],
            "expert_duration": [],
            "detection_duration": [],
        }
        for i, these_events in enumerate(events):
            idx_matching = idx_matchings[i]
            is_matched = idx_matching > -1
            expert_idx = np.where(is_matched)[0]
            detection_idx = idx_matching[is_matched]
            expert_stamps = these_events[expert_idx]
            detection_stamps = detections[i][detection_idx]
            seed_results["expert_idx"].append(expert_idx)
            seed_results["detection_idx"].append(detection_idx)
            seed_results["expert_duration"].append(
                expert_stamps[:, 1] - expert_stamps[:, 0]
            )
            seed_results["detection_duration"].append(
                detection_stamps[:, 1] - detection_stamps[:, 0]
            )
        results[k] = seed_results
    return results
def print_available_ckpt(optimal_thr_for_ckpt_dict, filter_dates, file=None):
    """Print the checkpoint keys whose date prefix lies within filter_dates.

    Args:
        optimal_thr_for_ckpt_dict: Mapping whose keys are strings starting
            with an integer date followed by "_" (only the keys are used).
        filter_dates: Two-element sequence (first, last) with the inclusive
            date bounds. A None entry means "no bound" on that side. The
            sequence is only read, never modified, so tuples are accepted
            and the caller's list keeps its None entries.
        file: Optional writable stream. When given, each selected key is
            written there in addition to standard output.
    """
    # Resolve open bounds into locals instead of writing back into the
    # caller's sequence.
    first_date = -1 if filter_dates[0] is None else filter_dates[0]
    last_date = 1e12 if filter_dates[1] is None else filter_dates[1]

    print("Available ckpt:")
    for key in optimal_thr_for_ckpt_dict:
        key_date = int(key.split("_")[0])
        if first_date <= key_date <= last_date:
            if file is not None:
                print(" %s" % key, file=file)
            print(" %s" % key)
file=None, 26 | decimal_precision=1, 27 | show_iqr_iou=True, 28 | ): 29 | str_to_print = "%%%d.%df \u00B1 %%%d.%df" % ( 30 | 1, 31 | decimal_precision, 32 | 1, 33 | decimal_precision, 34 | ) 35 | 36 | str_to_print_iqr = "%%%d.%df \u00B1 %%%d.%df [%%%d.%df - %%%d.%df]" % ( 37 | 1, 38 | decimal_precision, 39 | 1, 40 | decimal_precision, 41 | 1, 42 | decimal_precision, 43 | 1, 44 | decimal_precision, 45 | ) 46 | 47 | iou_curve_axis = performance_data_dict[constants.IOU_CURVE_AXIS] 48 | idx_to_show = misc.closest_index(iou_thr, iou_curve_axis) 49 | 50 | mean_f1_vs_iou = performance_data_dict[constants.F1_VS_IOU].mean(axis=0) 51 | std_f1_vs_iou = performance_data_dict[constants.F1_VS_IOU].std(axis=0) 52 | msg = "F1 " + str_to_print % ( 53 | 100 * mean_f1_vs_iou[idx_to_show], 54 | 100 * std_f1_vs_iou[idx_to_show], 55 | ) 56 | 57 | mean_rec_vs_iou = performance_data_dict[constants.RECALL_VS_IOU].mean(axis=0) 58 | std_rec_vs_iou = performance_data_dict[constants.RECALL_VS_IOU].std(axis=0) 59 | msg = ( 60 | msg 61 | + ", Recall " 62 | + str_to_print 63 | % (100 * mean_rec_vs_iou[idx_to_show], 100 * std_rec_vs_iou[idx_to_show]) 64 | ) 65 | 66 | mean_pre_vs_iou = performance_data_dict[constants.PRECISION_VS_IOU].mean(axis=0) 67 | std_pre_vs_iou = performance_data_dict[constants.PRECISION_VS_IOU].std(axis=0) 68 | msg = ( 69 | msg 70 | + ", Precision " 71 | + str_to_print 72 | % (100 * mean_pre_vs_iou[idx_to_show], 100 * std_pre_vs_iou[idx_to_show]) 73 | ) 74 | 75 | msg = ( 76 | msg 77 | + ", AF1 " 78 | + str_to_print 79 | % ( 80 | 100 * performance_data_dict[constants.MEAN_AF1].mean(), 81 | 100 * performance_data_dict[constants.MEAN_AF1].std(), 82 | ) 83 | ) 84 | 85 | if show_iqr_iou: 86 | msg = ( 87 | msg 88 | + ", IoU " 89 | + str_to_print_iqr 90 | % ( 91 | 100 * performance_data_dict[constants.MEAN_IOU].mean(), 92 | 100 * performance_data_dict[constants.MEAN_IOU].std(), 93 | 100 * performance_data_dict[constants.IQR_LOW_IOU].mean(), 94 | 100 * 
performance_data_dict[constants.IQR_HIGH_IOU].mean(), 95 | ) 96 | ) 97 | else: 98 | msg = ( 99 | msg 100 | + ", IoU " 101 | + str_to_print 102 | % ( 103 | 100 * performance_data_dict[constants.MEAN_IOU].mean(), 104 | 100 * performance_data_dict[constants.MEAN_IOU].std(), 105 | ) 106 | ) 107 | 108 | msg = msg + " for %s" % label 109 | if file is not None: 110 | print(msg, file=file) 111 | print(msg) 112 | 113 | 114 | def print_formatted_performance_at_iou( 115 | performance_data_dict, 116 | iou_thr, 117 | label, 118 | print_header=True, 119 | decimal_precision=1, 120 | file=None, 121 | show_iqr_iou=True, 122 | ): 123 | iou_curve_axis = performance_data_dict[constants.IOU_CURVE_AXIS] 124 | idx_to_show = misc.closest_index(iou_thr, iou_curve_axis) 125 | 126 | str_to_print = "%%%d.%df \u00B1 %%%d.%df" % ( 127 | 1, 128 | decimal_precision, 129 | 1, 130 | decimal_precision, 131 | ) 132 | 133 | str_to_print_iqr = "%%%d.%df \u00B1 %%%d.%df [%%%d.%df - %%%d.%df]" % ( 134 | 1, 135 | decimal_precision, 136 | 1, 137 | decimal_precision, 138 | 1, 139 | decimal_precision, 140 | 1, 141 | decimal_precision, 142 | ) 143 | 144 | if print_header: 145 | if file is not None: 146 | print("Model; F1; Recall; Precision; AF1; IoU", file=file) 147 | print("Model; F1; Recall; Precision; AF1; IoU") 148 | 149 | msg = label 150 | 151 | mean_f1_vs_iou = performance_data_dict[constants.F1_VS_IOU].mean(axis=0) 152 | std_f1_vs_iou = performance_data_dict[constants.F1_VS_IOU].std(axis=0) 153 | msg = ( 154 | msg 155 | + "; " 156 | + str_to_print 157 | % (100 * mean_f1_vs_iou[idx_to_show], 100 * std_f1_vs_iou[idx_to_show]) 158 | ) 159 | 160 | mean_rec_vs_iou = performance_data_dict[constants.RECALL_VS_IOU].mean(axis=0) 161 | std_rec_vs_iou = performance_data_dict[constants.RECALL_VS_IOU].std(axis=0) 162 | msg = ( 163 | msg 164 | + "; " 165 | + str_to_print 166 | % (100 * mean_rec_vs_iou[idx_to_show], 100 * std_rec_vs_iou[idx_to_show]) 167 | ) 168 | 169 | mean_pre_vs_iou = 
def split_mass(subject_ids, split_id, train_fraction=0.75, verbose=False):
    """Split subject ids into sorted training and validation lists.

    Subject ids is the sorted list of non-testing ids.

    The validation set is a window of `n_val` positions inside two
    concatenated seeded permutations of the subject indices. The window
    offset is `split_id * n_val`; the first permutation is seeded with the
    number of whole passes over the subjects implied by that offset. If the
    window contains a repeated index, the second permutation is redrawn
    with another seed until all indices are distinct.

    Returns:
        (train_ids, val_ids), both sorted; train_ids holds every id that is
        not in val_ids.
    """
    n_subjects = len(subject_ids)
    n_train = int(n_subjects * train_fraction)
    if verbose:
        print("Split IDs: Total %d -- Training %d" % (n_subjects, n_train))
    n_val = n_subjects - n_train

    offset = split_id * n_val
    epoch = int(offset / n_subjects)
    window_start = offset % n_subjects
    window = slice(window_start, window_start + n_val)

    # The first permutation only depends on the epoch, so draw it once.
    first_perm = np.random.RandomState(seed=epoch).permutation(n_subjects)
    seed_shift = 1
    val_idx = None
    while val_idx is None:
        second_perm = np.random.RandomState(seed=epoch + seed_shift).permutation(
            n_subjects
        )
        candidate = np.concatenate([first_perm, second_perm])[window]
        if np.unique(candidate).size == n_val:
            val_idx = candidate
        else:
            print("Attempting new split due to replication in val set")
            seed_shift = seed_shift + 1000

    val_ids = sorted(subject_ids[i] for i in val_idx)
    train_ids = sorted(sub_id for sub_id in subject_ids if sub_id not in val_ids)
    return train_ids, val_ids
# ###########################
# From here, the functions used by the above ones are defined.


def resample_signal(signal, fs_old, fs_new):
    """Return `signal` resampled from fs_old Hz to fs_new Hz as float32.

    The up/down factors given to scipy's polyphase resampler are the two
    rates divided by their greatest common divisor, so both rates must be
    integers.
    """
    common_divisor = math.gcd(fs_new, fs_old)
    up_factor = int(fs_new / common_divisor)
    down_factor = int(fs_old / common_divisor)
    resampled = resample_poly(signal, up_factor, down_factor)
    return np.array(resampled, dtype=np.float32)


def resample_signal_linear(signal, fs_old, fs_new):
    """Return `signal` resampled from fs_old Hz to fs_new Hz.

    This implementation uses simple linear interpolation to achieve this.
    The new time axis starts at the first original timestamp and stops
    before the last one (np.arange excludes its end point).
    """
    n_samples = len(signal)
    t_old = np.cumsum(np.ones(n_samples) / fs_old)
    t_new = np.arange(t_old[0], t_old[-1], 1 / fs_new)
    interpolator = interp1d(t_old, signal)
    return interpolator(t_new)


def broad_filter(signal, fs, lowcut=0.1, highcut=35):
    """Return `signal` (sampled at fs Hz) band-passed to [lowcut, highcut] Hz.

    A Butterworth band-pass designed with order argument 3 is applied
    forward and backward (filtfilt), which yields zero phase distortion.
    """
    nyquist = 0.5 * fs
    # butter expects the band edges as fractions of the Nyquist frequency.
    band_edges = (lowcut / nyquist, highcut / nyquist)
    numerator, denominator = butter(3, band_edges, btype="band")
    return filtfilt(numerator, denominator, signal)
def filter_duration_stamps(marks, fs, min_duration, max_duration, repair_long=True):
    """Drop marks that are too short or too long, optionally cropping long ones.

    Marks are assumed to be sample-stamps in an (n, 2) array and durations
    are given in seconds; a mark's duration is (end - start + 1) / fs.

    - If min_duration is not None, marks shorter than it are removed.
    - If max_duration is not None and repair_long is False, marks longer
      than it are removed.
    - If max_duration is not None and repair_long is True, marks longer
      than 2 * max_duration are removed and the remaining marks longer than
      max_duration are cropped symmetrically to keep their central part.
    - If both limits are None, or marks is empty, marks is returned as is.
    """
    limits_disabled = min_duration is None and max_duration is None
    if marks.size == 0 or limits_disabled:
        return marks

    durations = (marks[:, 1] - marks[:, 0] + 1) / fs

    if min_duration is not None:
        # Remove too short marks
        long_enough = durations >= min_duration
        marks = marks[long_enough, :]
        durations = durations[long_enough]

    if max_duration is None:
        return marks

    if not repair_long:
        # No repairing, simply remove
        return marks[durations <= max_duration, :]

    # Remove weird annotations (extremely long)
    repairable = durations <= 2 * max_duration
    marks = marks[repairable, :]
    durations = durations[repairable]

    # Samples to trim from each side; zero for marks already within the limit.
    excess = np.clip(durations - max_duration, 0, None)
    trim = ((fs * excess + 1) / 2).astype(np.int32)
    return marks + np.stack([trim, -trim], axis=1)
def filter_iir_lowpass(signal, fs, highcut=4):
    """Return `signal` (sampled at fs Hz) low-pass filtered at highcut Hz.

    An order-3 Butterworth low-pass is applied forward and backward
    (filtfilt), which yields zero phase distortion.
    """
    # butter expects the cutoff as a fraction of the Nyquist frequency.
    normalized_cutoff = highcut / (0.5 * fs)
    numerator, denominator = butter(3, normalized_cutoff, btype="low")
    return filtfilt(numerator, denominator, signal)
def _combine_batch_fn(tensors_1, tensors_2):
    """Concatenate two tuples of tensors element-wise along the batch axis.

    The k-th output is tensors_1[k] followed by tensors_2[k] along axis 0.
    The number of outputs equals len(tensors_1).
    """
    return tuple(
        tf.concat([tensors_1[k], tensors_2[k]], axis=0)
        for k in range(len(tensors_1))
    )
98 | """ 99 | with tf.name_scope(name): 100 | with tf.device("/cpu:0"): 101 | dataset = tf.data.Dataset.from_tensor_slices(tensors_ph) 102 | if shuffle_buffer_size > 0: 103 | dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) 104 | if repeat: 105 | dataset = dataset.repeat() 106 | if map_fn is not None: 107 | dataset = dataset.map(map_fn) 108 | dataset = dataset.batch(batch_size=batch_size) 109 | if prefetch_buffer_size > 0: 110 | dataset = dataset.prefetch(buffer_size=prefetch_buffer_size) 111 | iterator = dataset.make_initializable_iterator() 112 | return iterator 113 | 114 | 115 | def get_global_iterator(handle_ph, iterators_list, name=None): 116 | """Builds a global iterator that can switch between two iterators. 117 | 118 | Args: 119 | handle_ph: (Tensor) Placeholder of type tf.string and shape [] that 120 | will be fed with the proper string_handle. 121 | iterators_list: (list of Iterator) List of the iterators from where we 122 | can obtain inputs. 123 | name: (Optional, string, defaults to None) Name for the operation. 124 | 125 | Returns: 126 | global_iterator: (Iterator) Iterator that will switch between iterator_1 127 | and iterator_2 according to the handle fed to handle_ph. 
128 | """ 129 | with tf.name_scope(name): 130 | with tf.device("/cpu:0"): 131 | global_iterator = tf.data.Iterator.from_string_handle( 132 | handle_ph, 133 | iterators_list[0].output_types, 134 | iterators_list[0].output_shapes, 135 | ) 136 | return global_iterator 137 | -------------------------------------------------------------------------------- /sleeprnn/nn/metrics.py: -------------------------------------------------------------------------------- 1 | """Module that defines useful metrics to monitor a model.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from sleeprnn.common import constants 10 | 11 | 12 | def confusion_matrix(logits, labels, masks=None): 13 | """Returns TP, FP and FN""" 14 | with tf.variable_scope("confusion_matrix"): 15 | predictions_sparse = tf.argmax(logits, axis=-1) 16 | labels_zero = tf.equal(labels, tf.zeros_like(labels)) 17 | labels_one = tf.equal(labels, tf.ones_like(labels)) 18 | predictions_zero = tf.equal( 19 | predictions_sparse, tf.zeros_like(predictions_sparse) 20 | ) 21 | predictions_one = tf.equal(predictions_sparse, tf.ones_like(predictions_sparse)) 22 | 23 | valid_samples = 1.0 if masks is None else tf.cast(masks, tf.float32) 24 | 25 | tp_samples = tf.cast(tf.logical_and(labels_one, predictions_one), tf.float32) 26 | tp = tf.reduce_sum(valid_samples * tp_samples) 27 | 28 | fp_samples = tf.cast(tf.logical_and(labels_zero, predictions_one), tf.float32) 29 | fp = tf.reduce_sum(valid_samples * fp_samples) 30 | 31 | fn_samples = tf.cast(tf.logical_and(labels_one, predictions_zero), tf.float32) 32 | fn = tf.reduce_sum(valid_samples * fn_samples) 33 | 34 | return tp, fp, fn 35 | 36 | 37 | def precision_recall_f1score(tp, fp, fn): 38 | """Return Precision, Recall, and F1-Score metrics.""" 39 | with tf.variable_scope("precision"): 40 | # Edge case: no detections -> precision 1 41 | precision = tf.cond( 42 | 
pred=tf.equal((tp + fp), 0), 43 | true_fn=lambda: 1.0, 44 | false_fn=lambda: tp / (tp + fp), 45 | ) 46 | with tf.variable_scope("recall"): 47 | # Edge case: no marks -> recall 1 48 | recall = tf.cond( 49 | pred=tf.equal((tp + fn), 0), 50 | true_fn=lambda: 1.0, 51 | false_fn=lambda: tp / (tp + fn), 52 | ) 53 | with tf.variable_scope("f1_score"): 54 | # Edge case: precision and recall 0 -> f1 score 0 55 | f1_score = tf.cond( 56 | pred=tf.equal((2 * tp + fn + fp), 0), 57 | true_fn=lambda: 0.0, 58 | false_fn=lambda: 2 * tp / (2 * tp + fn + fp), 59 | ) 60 | return precision, recall, f1_score 61 | -------------------------------------------------------------------------------- /sleeprnn/nn/optimizers.py: -------------------------------------------------------------------------------- 1 | """Module that defines optimizers to train a model.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from sleeprnn.common import constants 10 | from sleeprnn.nn import adam_w 11 | 12 | 13 | def generic_optimizer_fn(optimizer, loss, clip_norm): 14 | """Applies the optimizer to the loss.""" 15 | 16 | if type(optimizer) == adam_w.AdamW: 17 | train_vars = tf.trainable_variables() 18 | grads = optimizer.get_gradients(loss, train_vars) 19 | original_gvs = [(grad, var) for grad, var in zip(grads, train_vars)] 20 | else: 21 | original_gvs = optimizer.compute_gradients(loss) 22 | 23 | gradients = [gv[0] for gv in original_gvs] 24 | grad_norm = tf.global_norm(gradients, name="gradient_norm") 25 | grad_norm_summ = tf.summary.scalar("original_grad_norm", grad_norm) 26 | 27 | if clip_norm is not None: 28 | gradients, _ = tf.clip_by_global_norm( 29 | gradients, clip_norm, use_norm=grad_norm, name="clipping" 30 | ) 31 | clipped_grad_norm = tf.global_norm(gradients, name="new_gradient_norm") 32 | variables = [gv[1] for gv in original_gvs] 33 | new_gvs = [(grad, var) for grad, var in 
zip(gradients, variables)] 34 | clipped_grad_norm_summ = tf.summary.scalar( 35 | "clipped_grad_norm", clipped_grad_norm 36 | ) 37 | grad_norm_summ = tf.summary.merge([grad_norm_summ, clipped_grad_norm_summ]) 38 | else: 39 | new_gvs = original_gvs 40 | 41 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # For BN 42 | with tf.control_dependencies(update_ops): 43 | train_step = optimizer.apply_gradients(new_gvs) 44 | reset_optimizer_op = tf.variables_initializer(optimizer.variables()) 45 | return train_step, reset_optimizer_op, grad_norm_summ 46 | 47 | 48 | def adam_optimizer_fn(loss, learning_rate, clip_norm): 49 | """Returns the optimizer operation to minimize the loss with Adam. 50 | 51 | Args: 52 | loss: (tensor) loss to be minimized 53 | learning_rate: (float) learning rate for the optimizer 54 | clip_norm: (float) Global norm to clip. 55 | """ 56 | with tf.name_scope(constants.ADAM_OPTIMIZER): 57 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 58 | return generic_optimizer_fn(optimizer, loss, clip_norm) 59 | 60 | 61 | def adam_w_optimizer_fn(loss, learning_rate, weight_decay, clip_norm): 62 | """Returns the optimizer operation to minimize the loss with Adam W. 63 | 64 | Args: 65 | loss: (tensor) loss to be minimized 66 | learning_rate: (float) learning rate for the optimizer 67 | weight_decay: (float) Weight decay for the optimizer. 68 | clip_norm: (float) Global norm to clip. 69 | """ 70 | with tf.name_scope(constants.ADAM_W_OPTIMIZER): 71 | optimizer = adam_w.AdamW(weight_decay, learning_rate=learning_rate) 72 | return generic_optimizer_fn(optimizer, loss, clip_norm) 73 | 74 | 75 | def sgd_optimizer_fn(loss, learning_rate, momentum, clip_norm, use_nesterov): 76 | """Returns the optimizer operation to minimize the loss with SGD with 77 | momentum. 78 | 79 | Args: 80 | loss: (tensor) loss to be minimized 81 | learning_rate: (float) learning rate for the optimizer 82 | momentum: (Optional, float) momentum for the optimizer. 
83 | clip_norm: (float) Global norm to clip. 84 | use_nesterov: (bool) whether to use 85 | Nesterov momentum instead of regular momentum. 86 | """ 87 | with tf.name_scope(constants.SGD_OPTIMIZER): 88 | optimizer = tf.train.MomentumOptimizer( 89 | learning_rate, momentum, use_nesterov=use_nesterov 90 | ) 91 | return generic_optimizer_fn(optimizer, loss, clip_norm) 92 | 93 | 94 | def rmsprop_optimizer_fn(loss, learning_rate, momentum, clip_norm): 95 | """Returns the optimizer operation to minimize the loss with RMSProp 96 | 97 | Args: 98 | loss: (tensor) loss to be minimized 99 | learning_rate: (float) learning rate for the optimizer 100 | momentum: (Optional, float) momentum for the optimizer. 101 | clip_norm: (float) Global norm to clip. 102 | """ 103 | with tf.name_scope(constants.RMSPROP_OPTIMIZER): 104 | optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 105 | return generic_optimizer_fn(optimizer, loss, clip_norm) 106 | -------------------------------------------------------------------------------- /sleeprnn/nn/wave_augment.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from scipy.signal import firwin 7 | import tensorflow as tf 8 | 9 | 10 | def apply_fir_filter_tf(signal, kernel): 11 | """For single signal, not batch""" 12 | signal = tf.reshape(signal, shape=[1, 1, -1, 1]) 13 | kernel = tf.reshape(kernel, shape=[1, -1, 1, 1]) 14 | with tf.device("/cpu:0"): 15 | new_signal = tf.nn.conv2d( 16 | input=signal, filter=kernel, strides=[1, 1, 1, 1], padding="SAME" 17 | ) 18 | new_signal = new_signal[0, 0, :, 0] 19 | return new_signal 20 | 21 | 22 | def random_window_tf(signal_size, window_min_size, window_max_size): 23 | window_size = tf.random.uniform([], minval=window_min_size, maxval=window_max_size) 24 | start_sample = tf.random.uniform( 25 | [], 
minval=0, maxval=(signal_size - window_size - 1) 26 | ) 27 | k_array = np.arange(signal_size) 28 | offset_1 = start_sample + 0.1 * window_size 29 | offset_2 = start_sample + 0.9 * window_size 30 | scaling = 0.1 * window_size / 4 31 | window_onset = tf.math.sigmoid((k_array - offset_1) / scaling) 32 | window_offset = tf.math.sigmoid((k_array - offset_2) / scaling) 33 | window = window_onset - window_offset 34 | return window 35 | 36 | 37 | def random_smooth_function_tf( 38 | signal_size, function_min_val, function_max_val, lp_filter_size 39 | ): 40 | lp_filter = np.hanning(lp_filter_size).astype(np.float32) 41 | lp_filter /= lp_filter.sum() 42 | noise_vector = tf.random.uniform([signal_size], minval=-1, maxval=1) 43 | noise_vector = apply_fir_filter_tf(noise_vector, lp_filter) 44 | # Set noise to [0, 1] range 45 | min_val = tf.reduce_min(noise_vector) 46 | max_val = tf.reduce_max(noise_vector) 47 | noise_vector = (noise_vector - min_val) / (max_val - min_val) 48 | # Set to [function_min_val, function_max_val] range 49 | noise_vector = function_min_val + noise_vector * ( 50 | function_max_val - function_min_val 51 | ) 52 | return noise_vector 53 | 54 | 55 | def lowpass_tf(signal, fs, cutoff, filter_duration_ref=6, wave_expansion_factor=0.5): 56 | numtaps = fs * filter_duration_ref / (cutoff**wave_expansion_factor) 57 | numtaps = int(2 * (numtaps // 2) + 1) # ensure odd numtaps 58 | lp_kernel = firwin(numtaps, cutoff=cutoff, window="hann", fs=fs).astype(np.float32) 59 | lp_kernel /= lp_kernel.sum() 60 | new_signal = apply_fir_filter_tf(signal, lp_kernel) 61 | return new_signal 62 | 63 | 64 | def highpass_tf(signal, fs, cutoff, filter_duration_ref=6, wave_expansion_factor=0.5): 65 | numtaps = fs * filter_duration_ref / (cutoff**wave_expansion_factor) 66 | numtaps = int(2 * (numtaps // 2) + 1) # ensure odd numtaps 67 | lp_kernel = firwin(numtaps, cutoff=cutoff, window="hann", fs=fs).astype(np.float32) 68 | lp_kernel /= lp_kernel.sum() 69 | # HP = delta - LP 70 | 
hp_kernel = -lp_kernel 71 | hp_kernel[numtaps // 2] += 1 72 | new_signal = apply_fir_filter_tf(signal, hp_kernel) 73 | return new_signal 74 | 75 | 76 | def bandpass_tf( 77 | signal, fs, lowcut, highcut, filter_duration_ref=6, wave_expansion_factor=0.5 78 | ): 79 | new_signal = signal 80 | if lowcut is not None: 81 | new_signal = highpass_tf( 82 | new_signal, fs, lowcut, filter_duration_ref, wave_expansion_factor 83 | ) 84 | if highcut is not None: 85 | new_signal = lowpass_tf( 86 | new_signal, fs, highcut, filter_duration_ref, wave_expansion_factor 87 | ) 88 | return new_signal 89 | 90 | 91 | def generate_soft_mask_from_labels_tf( 92 | labels, fs, mask_lp_filter_duration=0.2, use_background=True 93 | ): 94 | lp_filter_size = int(fs * mask_lp_filter_duration) 95 | labels = tf.cast(labels, tf.float32) 96 | # Enlarge labels 97 | expand_filter = np.ones(lp_filter_size).astype(np.float32) 98 | expanded_labels = apply_fir_filter_tf(labels, expand_filter) 99 | expanded_labels = tf.clip_by_value(expanded_labels, 0, 1) 100 | # Now filter 101 | lp_filter = np.hanning(lp_filter_size).astype(np.float32) 102 | lp_filter /= lp_filter.sum() 103 | smooth_labels = apply_fir_filter_tf(expanded_labels, lp_filter) 104 | if use_background: 105 | soft_mask = 1 - smooth_labels 106 | else: 107 | soft_mask = smooth_labels 108 | return soft_mask 109 | 110 | 111 | def generate_wave_tf( 112 | signal_size, # Number of samples 113 | fs, # [Hz] 114 | max_amplitude, # signal units 115 | min_frequency, # [Hz] 116 | max_frequency, # [Hz] 117 | frequency_bandwidth, # [Hz] 118 | min_duration, # [s] 119 | max_duration, # [s] 120 | mask, # [0, 1] 121 | frequency_lp_filter_duration=0.5, # [s] 122 | amplitude_lp_filter_duration=0.5, # [s] 123 | return_intermediate_steps=False, 124 | ): 125 | # This is ok to be numpy 126 | window_min_size = int(fs * min_duration) 127 | window_max_size = int(fs * max_duration) 128 | frequency_lp_filter_size = int(fs * frequency_lp_filter_duration) 129 | 
amplitude_lp_filter_size = int(fs * amplitude_lp_filter_duration) 130 | # Oscillation 131 | lower_freq = tf.random.uniform( 132 | [], minval=min_frequency, maxval=max_frequency - frequency_bandwidth 133 | ) 134 | upper_freq = lower_freq + frequency_bandwidth 135 | wave_freq = random_smooth_function_tf( 136 | signal_size, lower_freq, upper_freq, frequency_lp_filter_size 137 | ) 138 | wave_phase = 2 * np.pi * tf.math.cumsum(wave_freq) / fs 139 | oscillation = tf.math.cos(wave_phase) 140 | # Amplitude 141 | amplitude_high = tf.random.uniform([], minval=0, maxval=max_amplitude) 142 | amplitude_low = tf.random.uniform([], minval=0, maxval=amplitude_high) 143 | amplitude = random_smooth_function_tf( 144 | signal_size, amplitude_low, amplitude_high, amplitude_lp_filter_size 145 | ) 146 | # Window 147 | window = random_window_tf(signal_size, window_min_size, window_max_size) 148 | # Total wave 149 | generated_wave = window * amplitude * oscillation 150 | # Optional masking 151 | if mask is not None: 152 | generated_wave = generated_wave * mask 153 | if return_intermediate_steps: 154 | intermediate_steps = { 155 | "oscillation": oscillation, 156 | "amplitude": amplitude, 157 | "window": window, 158 | "mask": mask, 159 | } 160 | return generated_wave, intermediate_steps 161 | return generated_wave 162 | 163 | 164 | def generate_anti_wave_tf( 165 | signal, 166 | signal_size, # number of samples 167 | fs, # [Hz] 168 | lowcut, # [Hz] 169 | highcut, # [Hz] 170 | min_duration, # [s] 171 | max_duration, # [s] 172 | max_attenuation, # [0, 1] 173 | mask, # [0, 1] 174 | return_intermediate_steps=False, 175 | ): 176 | # This is ok to be numpy 177 | window_min_size = int(fs * min_duration) 178 | window_max_size = int(fs * max_duration) 179 | # Oscillation (opposite sign of band signal) and attenuation factor 180 | oscillation = -bandpass_tf(signal, fs, lowcut, highcut) 181 | attenuation_factor = tf.random.uniform([], minval=0, maxval=max_attenuation) 182 | # Window 183 | window = 
random_window_tf(signal_size, window_min_size, window_max_size) 184 | # Total wave 185 | generated_wave = window * attenuation_factor * oscillation 186 | # Optional masking 187 | if mask is not None: 188 | generated_wave = generated_wave * mask 189 | if return_intermediate_steps: 190 | intermediate_steps = { 191 | "oscillation": -oscillation, 192 | "attenuation": -attenuation_factor, 193 | "window": window, 194 | "mask": mask, 195 | } 196 | return generated_wave, intermediate_steps 197 | return generated_wave 198 | 199 | 200 | def generate_base_oscillation( 201 | signal_size, # Number of samples 202 | fs, # [Hz] 203 | min_frequency, # [Hz] 204 | max_frequency, # [Hz] 205 | frequency_variation_width, # [Hz] 206 | min_amplitude, # signal units 207 | max_amplitude, # signal units 208 | amplitude_relative_variation_width, # relative 209 | frequency_lp_filter_duration=0.5, # [s] 210 | amplitude_lp_filter_duration=0.5, # [s] 211 | ): 212 | frequency_lp_filter_size = int(fs * frequency_lp_filter_duration) 213 | amplitude_lp_filter_size = int(fs * amplitude_lp_filter_duration) 214 | # Oscillation 215 | central_freq = tf.random.uniform([], minval=min_frequency, maxval=max_frequency) 216 | lower_freq = central_freq - 0.5 * frequency_variation_width 217 | upper_freq = central_freq + 0.5 * frequency_variation_width 218 | wave_freq = random_smooth_function_tf( 219 | signal_size, lower_freq, upper_freq, frequency_lp_filter_size 220 | ) 221 | wave_phase = 2 * np.pi * tf.math.cumsum(wave_freq) / fs 222 | oscillation = tf.math.cos(wave_phase) 223 | # Amplitude 224 | central_amplitude = tf.random.uniform( 225 | [], minval=min_amplitude, maxval=max_amplitude 226 | ) 227 | amplitude_high = central_amplitude * (1 + 0.5 * amplitude_relative_variation_width) 228 | amplitude_low = central_amplitude * (1 - 0.5 * amplitude_relative_variation_width) 229 | amplitude = random_smooth_function_tf( 230 | signal_size, amplitude_low, amplitude_high, amplitude_lp_filter_size 231 | ) 232 | return 
amplitude * oscillation, central_amplitude, central_freq 233 | 234 | 235 | def generate_false_spindle_single_contamination( 236 | signal, 237 | signal_size, # Number of samples 238 | fs, # [Hz] 239 | duration_range, # [s] 240 | bandstop_cutoff, # [Hz] 241 | spindle_frequency_range, # [Hz] 242 | spindle_frequency_variation_width, # [Hz] 243 | spindle_amplitude_absolute_range, # signal units 244 | spindle_amplitude_relative_variation_width, 245 | contamination_frequency_range, # [Hz] IMPORTANT 246 | contamination_frequency_variation_width, # [Hz] 247 | contamination_amplitude_relative_range, # IMPORTANT 248 | contamination_amplitude_relative_variation_width, 249 | mask, 250 | min_distance_between_frequencies=1.5, # Hz 251 | frequency_lp_filter_duration=0.5, # [s] 252 | amplitude_lp_filter_duration=0.5, # [s] 253 | ): 254 | 255 | if ( 256 | contamination_frequency_range[0] 257 | > spindle_frequency_range[0] - min_distance_between_frequencies 258 | ): 259 | raise ValueError( 260 | "Contamination interval %s Hz incompatible with spindle interval %s Hz and min distance %s Hz" 261 | % ( 262 | contamination_frequency_range, 263 | spindle_frequency_range, 264 | min_distance_between_frequencies, 265 | ) 266 | ) 267 | 268 | # Prepare window 269 | window_min_size = int(fs * duration_range[0]) 270 | window_max_size = int(fs * duration_range[1]) 271 | window = random_window_tf(signal_size, window_min_size, window_max_size) 272 | window = mask * window if (mask is not None) else window 273 | 274 | part_to_remove = bandpass_tf(signal, fs, bandstop_cutoff[0], bandstop_cutoff[1]) 275 | 276 | base_sigma_wave, sigma_central_amp, sigma_central_freq = generate_base_oscillation( 277 | signal_size, 278 | fs, 279 | spindle_frequency_range[0], 280 | spindle_frequency_range[1], 281 | spindle_frequency_variation_width, 282 | spindle_amplitude_absolute_range[0], 283 | spindle_amplitude_absolute_range[1], 284 | spindle_amplitude_relative_variation_width, 285 | frequency_lp_filter_duration, 286 
| amplitude_lp_filter_duration, 287 | ) 288 | 289 | contamination_upper_freq = tf.math.minimum( 290 | float(contamination_frequency_range[1]), 291 | sigma_central_freq - min_distance_between_frequencies, 292 | ) 293 | base_contamination_wave, _, _ = generate_base_oscillation( 294 | signal_size, 295 | fs, 296 | contamination_frequency_range[0], 297 | contamination_upper_freq, 298 | contamination_frequency_variation_width, 299 | contamination_amplitude_relative_range[0] * sigma_central_amp, 300 | contamination_amplitude_relative_range[1] * sigma_central_amp, 301 | contamination_amplitude_relative_variation_width, 302 | frequency_lp_filter_duration, 303 | amplitude_lp_filter_duration, 304 | ) 305 | 306 | # Total wave 307 | generated_wave = window * ( 308 | -part_to_remove + base_sigma_wave + base_contamination_wave 309 | ) 310 | return generated_wave 311 | --------------------------------------------------------------------------------