├── dosed
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── dosed2.py
│   │   ├── dosed1.py
│   │   ├── dosed3.py
│   │   └── base.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base_adam.py
│   │   └── base.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── dataset.py
│   ├── utils
│   │   ├── binary_to_array.py
│   │   ├── decode.py
│   │   ├── misc.py
│   │   ├── colorize.py
│   │   ├── __init__.py
│   │   ├── encode.py
│   │   ├── non_maximum_suppression.py
│   │   ├── jaccard_overlap.py
│   │   ├── data_from_h5.py
│   │   ├── match_events_to_default_localizations.py
│   │   └── logger.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── normalizers.py
│   │   └── regularization.py
│   └── functions
│       ├── __init__.py
│       ├── focal_loss.py
│       ├── random_negative_mining_loss.py
│       ├── worst_negative_mining_loss.py
│       ├── detection.py
│       ├── compute_metrics_dataset.py
│       ├── metrics.py
│       └── simple_loss.py
├── dosed_detection.png
├── tests
│   ├── test_files
│   │   └── h5
│   │       ├── 23c61485-a9a5-478b-8115-f34b883876b8.h5
│   │       └── 9a3b2e69-3e81-42f3-bba2-4afd8c7a0268.h5
│   ├── test_regularization.py
│   ├── test_encode_decode.py
│   ├── test_nms.py
│   ├── test_trainers.py
│   ├── test_models.py
│   └── test_dataset.py
├── Dockerfile
├── requirements.txt
├── Makefile
├── .circleci
│   └── config.yml
├── setup.py
├── minimum_example
│   ├── download_data.py
│   ├── to_h5.py
│   └── train_and_evaluate_dosed.ipynb
├── LICENCE
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/dosed/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.0.4"
--------------------------------------------------------------------------------
/dosed_detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dreem-Organization/dosed/HEAD/dosed_detection.png
--------------------------------------------------------------------------------
/tests/test_files/h5/23c61485-a9a5-478b-8115-f34b883876b8.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dreem-Organization/dosed/HEAD/tests/test_files/h5/23c61485-a9a5-478b-8115-f34b883876b8.h5
--------------------------------------------------------------------------------
/tests/test_files/h5/9a3b2e69-3e81-42f3-bba2-4afd8c7a0268.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dreem-Organization/dosed/HEAD/tests/test_files/h5/9a3b2e69-3e81-42f3-bba2-4afd8c7a0268.h5
--------------------------------------------------------------------------------
/dosed/models/__init__.py:
--------------------------------------------------------------------------------
from .dosed1 import DOSED1
from .dosed2 import DOSED2
from .dosed3 import DOSED3

__all__ = [
    "DOSED1",
    "DOSED2",
    "DOSED3",
]
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM floydhub/pytorch:1.0.0-gpu.cuda9cudnn7-py3.38

RUN pip install pip --upgrade
COPY requirements.txt /requirements.txt
RUN pip install -r /requirements.txt

WORKDIR /workspace
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Cython==0.29
h5py==2.8.0
tqdm==4.28.1
boto3==1.9.31
joblib==0.12.5
matplotlib==3.0.2
pyedflib==0.1.14
torch==1.0.0
pytest==3.10.0
pytest-cov==2.6.0
--------------------------------------------------------------------------------
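For orientation, a minimal sketch of consuming what the two `__init__` files above expose — assuming the requirements have been installed (locally or via the Dockerfile) and the package itself installed, e.g. with `pip install -e .` (that step is an assumption; setup.py appears further down):

import dosed
from dosed.models import DOSED1, DOSED2, DOSED3  # re-exported by dosed/models/__init__.py

print(dosed.__version__)  # -> 0.0.4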
/dosed/trainers/__init__.py:
--------------------------------------------------------------------------------
from .base import TrainerBase
from .base_adam import TrainerBaseAdam

__all__ = [
    "TrainerBase",
    "TrainerBaseAdam",
]

trainers = {
    "basic": TrainerBase,
    "adam": TrainerBaseAdam,
}
--------------------------------------------------------------------------------
/dosed/datasets/__init__.py:
--------------------------------------------------------------------------------
from .dataset import EventDataset, BalancedEventDataset
from .utils import collate, get_train_validation_test

__all__ = [
    "EventDataset",
    "collate",
    "BalancedEventDataset",
    "get_train_validation_test",
]
--------------------------------------------------------------------------------
/tests/test_regularization.py:
--------------------------------------------------------------------------------
import torch

from dosed.preprocessing import GaussianNoise, RescaleNormal, Invert
from dosed.utils import Compose


def test_regularization():
    x = torch.rand(32, 25, 25)

    regularizer = Compose(
        [GaussianNoise(), RescaleNormal(), Invert()]
    )
    regularizer(x)
--------------------------------------------------------------------------------
/dosed/utils/binary_to_array.py:
--------------------------------------------------------------------------------
import numpy as np


def binary_to_array(x):
    """Return [start, end] intervals from a binary array.

    binary_to_array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1])
    -> [[4, 8], [11, 13]]
    """
    tmp = np.array([0] + list(x) + [0])
    return np.where((tmp[1:] - tmp[:-1]) != 0)[0].reshape((-1, 2)).tolist()
--------------------------------------------------------------------------------
/dosed/preprocessing/__init__.py:
--------------------------------------------------------------------------------
from .regularization import GaussianNoise, RescaleNormal, Invert
from .normalizers import clip, clip_and_normalize, mask_clip_and_normalize


normalizers = {
    "clip": clip,
    "clip_and_normalize": clip_and_normalize,
    "mask_clip_and_normalize": mask_clip_and_normalize,
}


__all__ = [
    "GaussianNoise",
    "RescaleNormal",
    "Invert",
    "normalizers",
]
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
DOWNLOAD_PATH ?= ./data

test:
	python -m pytest --cov-config .coveragerc --cov=./dosed --cov-fail-under=90 --cov-report=term-missing

download_example:
	@echo Downloading Data and converting to H5 in $(DOWNLOAD_PATH)
	python minimum_example/download_data.py $(DOWNLOAD_PATH)/downloads/
	python minimum_example/to_h5.py $(DOWNLOAD_PATH)/downloads/ $(DOWNLOAD_PATH)/h5/

start_docker:
	docker build -t lol .
	docker run --runtime=nvidia -it -v "${PWD}:/workspace" lol bash
--------------------------------------------------------------------------------
/dosed/utils/decode.py:
--------------------------------------------------------------------------------
import torch


def decode(localization, localizations_default):
    """Opposite of encode"""
    center_encoded, width_encoded = localization[:, 0], localization[:, 1]
    x_plus_y = (center_encoded * localizations_default[:, 1] + localizations_default[:, 0]) * 2
    y_minus_x = torch.exp(width_encoded) * localizations_default[:, 1]
    x = (x_plus_y - y_minus_x) / 2
    y = (x_plus_y + y_minus_x) / 2

    localization_decoded = torch.cat([x.unsqueeze(1), y.unsqueeze(1)], 1)
    return localization_decoded
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
version: 2
jobs:
  build_and_test:
    docker:
      - image: floydhub/pytorch:1.0.0-gpu.cuda9cudnn7-py3.38
    steps:
      - checkout  # checkout source code to working directory
      - run:
          name: Install requirements
          command: |  # use pip to install dependencies and then launch tests
            pip install pip --upgrade
            pip install -r requirements.txt
      - run:
          name: Launch Test
          command: make test
workflows:
  version: 2
  build_and_test:
    jobs:
      - build_and_test
--------------------------------------------------------------------------------
/dosed/utils/misc.py:
--------------------------------------------------------------------------------
class Compose(object):
    """From torchvision, with love."""

    def __init__(self, transformations):
        self.transformations = transformations

    def __call__(self, x):
        for transformation in self.transformations:
            x = transformation(x)
        return x

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transformations:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
--------------------------------------------------------------------------------
/dosed/utils/colorize.py:
--------------------------------------------------------------------------------
"""Colorize text displayed on terminal."""

color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38
)


def colorize(string, color, bold=False, highlight=False):
    """
    Colorize a string.

    This function was originally written by John Schulman.
21 | """ 22 | attr = [] 23 | num = color2num[color] 24 | if highlight: 25 | num += 10 26 | attr.append(str(num)) 27 | if bold: 28 | attr.append('1') 29 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import dosed 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="dosed", 9 | version=dosed.__version__, 10 | author="Dreem", 11 | author_email=" ", 12 | description="Implementation of DOSED algorithm for event detection", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/Dreem-Organization/dosed/", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | ) -------------------------------------------------------------------------------- /tests/test_encode_decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dosed.utils import encode, decode 4 | 5 | 6 | def test_encode_decode(): 7 | localizations_default = torch.FloatTensor([1]).repeat([4, 2]).uniform_() 8 | 9 | for line in range(localizations_default.shape[0]): 10 | if localizations_default[line, 0] > localizations_default[line, 1]: 11 | aux = localizations_default[line, 0] 12 | localizations_default[line, 0] = localizations_default[line, 1] 13 | localizations_default[line, 1] = aux 14 | 15 | localizations = encode(localizations_default, localizations_default) 16 | is_it_retrieved = decode(localizations, localizations_default) 17 | 18 | assert ((localizations_default - is_it_retrieved).numpy() > 1e-4).sum() == 0 19 | -------------------------------------------------------------------------------- /minimum_example/download_data.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore import UNSIGNED 3 | from botocore.client import Config 4 | import tqdm 5 | import os 6 | import sys 7 | 8 | download_directory = sys.argv[1] 9 | if not os.path.isdir(download_directory): 10 | os.makedirs(download_directory) 11 | 12 | bucket_name = 'dreem-dosed-minimum-example' 13 | 14 | client = boto3.client('s3', config=Config(signature_version=UNSIGNED)) 15 | 16 | bucket_objects = client.list_objects(Bucket='dreem-dosed-minimum-example')["Contents"] 17 | print("\n Downloading EDF files and annotations from S3") 18 | for bucket_object in tqdm.tqdm(bucket_objects): 19 | filename = bucket_object["Key"] 20 | client.download_file( 21 | Bucket=bucket_name, 22 | Key=filename, 23 | Filename=download_directory + "/{}".format(filename) 24 | ) 25 | -------------------------------------------------------------------------------- /dosed/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .encode import encode 2 | from .decode import decode 3 | from .jaccard_overlap import jaccard_overlap 4 | from .non_maximum_suppression import non_maximum_suppression 5 | from .match_events_to_default_localizations import match_events_localization_to_default_localizations 6 | from .binary_to_array import binary_to_array 7 | from .misc import Compose 8 | from .logger import Logger 9 | from .data_from_h5 import get_h5_data, 
from .colorize import colorize

__all__ = [
    "encode",
    "decode",
    "jaccard_overlap",
    "non_maximum_suppression",
    "match_events_localization_to_default_localizations",
    "binary_to_array",
    "Compose",
    "get_h5_data",
    "get_h5_events",
    "Logger",
    "colorize",
]
--------------------------------------------------------------------------------
/tests/test_nms.py:
--------------------------------------------------------------------------------
import torch

from dosed.utils import non_maximum_suppression


def test_non_maximum_suppression():
    localizations_scores = torch.FloatTensor(
        [
            [20, 50, 0.99],  # 0
            [10, 50, 0.97],  # 1
            [25, 42, 0.6],   # 2
            [30, 60, 0.98],  # 3

            [75, 85, 0.92],  # 4
            [72, 87, 0.90],  # 5
            [76, 85, 0.78],  # 6
            [80, 90, 0.91],  # 7
        ]
    )
    localizations = localizations_scores[:, :2] / 100
    scores = localizations_scores[:, -1]
    overlap = 0.4
    kept = [tuple([int(x * 100) for x in y])
            for y in non_maximum_suppression(localizations, scores, overlap)]
    to_keep = [(20, 50), (75, 85), (80, 90)]
    assert set(kept) == set(to_keep)
--------------------------------------------------------------------------------
/dosed/utils/encode.py:
--------------------------------------------------------------------------------
import torch


def encode(localization_match, localizations_default):
    """Convert matched localizations to offsets relative to their default localizations.

    localization_match has size [batch, number_of_localizations, 2] and contains the
    ground-truth matched localizations (x, y representation).
    localizations_default has size [number_of_localizations, 2].

    Returns localization_target of size [batch, number_of_localizations, 2].
    """
    center = (localization_match[:, 0] + localization_match[:, 1]) / 2 - localizations_default[:, 0]
    center = center / localizations_default[:, 1]
    width = torch.log((localization_match[:, 1] - localization_match[:, 0]) / localizations_default[:, 1])
    localization_target = torch.cat([center.unsqueeze(1), width.unsqueeze(1)], 1)
    return localization_target
--------------------------------------------------------------------------------
/dosed/utils/non_maximum_suppression.py:
--------------------------------------------------------------------------------
import torch


def non_maximum_suppression(localizations, scores, overlap=0.5):
    """1D NMS: greedily keep the highest-scoring segments, dropping overlapping ones."""
    x = localizations[:, 0]
    y = localizations[:, 1]

    areas = y - x
    order = scores.sort(0, descending=True)[1]
    keep = []
    while order.numel() > 1:
        i = order[0]
        keep.append([x[i], y[i]])
        order = order[1:]
        xx = torch.clamp(x[order], min=x[i].item())
        yy = torch.clamp(y[order], max=y[i].item())

        intersection = torch.clamp(yy - xx, min=0)

        intersection_over_union = intersection / (areas[i] + areas[order] - intersection)

        order = order[intersection_over_union <= overlap]

    keep.extend([[x[k], y[k]] for k in order])  # remaining element if order has size 1

    return keep
--------------------------------------------------------------------------------
/dosed/functions/__init__.py:
--------------------------------------------------------------------------------
from .simple_loss import DOSEDSimpleLoss
from .worst_negative_mining_loss import DOSEDWorstNegativeMiningLoss
from .random_negative_mining_loss import DOSEDRandomNegativeMiningLoss
from .focal_loss import DOSEDFocalLoss
from .detection import Detection
from .metrics import precision_function, recall_function, f1_function
from .compute_metrics_dataset import compute_metrics_dataset

loss_functions = {
    "simple": DOSEDSimpleLoss,
    "worst_negative_mining": DOSEDWorstNegativeMiningLoss,
    "focal": DOSEDFocalLoss,
    "random_negative_mining": DOSEDRandomNegativeMiningLoss,
}

available_score_functions = {
    "precision": precision_function(),
    "recall": recall_function(),
    "f1": f1_function(),
}

__all__ = [
    "loss_functions",
    "Detection",
    "available_score_functions",
    "compute_metrics_dataset",
]
--------------------------------------------------------------------------------
/dosed/utils/jaccard_overlap.py:
--------------------------------------------------------------------------------
import torch


def jaccard_overlap(localizations_a, localizations_b):
    """Jaccard overlap between two sets of segments: A ∩ B / (length_A + length_B - A ∩ B)

    localizations_a: tensor of localizations
    localizations_b: tensor of localizations
    """
    A = localizations_a.size(0)
    B = localizations_b.size(0)
    # intersection
    max_min = torch.max(localizations_a[:, 0].unsqueeze(1).expand(A, B),
                        localizations_b[:, 0].unsqueeze(0).expand(A, B))
    min_max = torch.min(localizations_a[:, 1].unsqueeze(1).expand(A, B),
                        localizations_b[:, 1].unsqueeze(0).expand(A, B))
    intersection = torch.clamp((min_max - max_min), min=0)
    length_a = (localizations_a[:, 1] - localizations_a[:, 0]).unsqueeze(1).expand(A, B)
    length_b = (localizations_b[:, 1] - localizations_b[:, 0]).unsqueeze(0).expand(A, B)
    overlaps = intersection / (length_a + length_b - intersection)
    return overlaps
--------------------------------------------------------------------------------
/dosed/datasets/utils.py:
--------------------------------------------------------------------------------
import random
import os

import torch


def collate(batch):
    """Collate function, needed because the number of events varies across samples."""
    batch_events = []
    batch_eegs = []
    for eeg, events in batch:
        batch_eegs.append(eeg)
        batch_events.append(events)
    return torch.stack(batch_eegs, 0), batch_events


def get_train_validation_test(h5_directory,
                              percent_test,
                              percent_validation,
                              seed=None):

    records = [x for x in os.listdir(h5_directory) if (x != ".cache" and x[-2:] == "h5")]

    random.seed(seed)
    index_test = int(len(records) * percent_test / 100)
    random.shuffle(records)
    test = records[:index_test]
    records_train = records[index_test:]

    index_validation = int(len(records_train) * percent_validation / 100)
    random.shuffle(records_train)
    validation = records_train[:index_validation]
    train = records_train[index_validation:]

    return train, validation, test
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Dreem (Valentin Thorey)
Copyright (c) 2018 Dreem (Stanislas Chambon)
Copyright (c) 2018 Dreem (Albert Bou)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/dosed/functions/focal_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from .simple_loss import DOSEDSimpleLoss


class DOSEDFocalLoss(DOSEDSimpleLoss):
    """Loss function inspired from https://github.com/amdegroot/ssd.pytorch"""

    def __init__(self,
                 number_of_classes,
                 device,
                 alpha=0.25,
                 gamma=2,
                 ):
        super(DOSEDFocalLoss, self).__init__(
            number_of_classes=number_of_classes,
            device=device)
        self.device = device
        self.number_of_classes = number_of_classes + 1  # eventlessness
        self.alpha = alpha
        self.gamma = gamma

    def get_classification_loss(self, index, classifications,
                                classifications_target):
        index_expanded = index.unsqueeze(2).expand_as(classifications)

        cross_entropy = F.cross_entropy(
            classifications[index_expanded.gt(0)
                            ].view(-1, self.number_of_classes),
            classifications_target[index.gt(0)],
            reduction="none",
        )
        pt = torch.exp(-cross_entropy)
        loss_classification = (
            self.alpha * ((1 - pt) ** self.gamma) * cross_entropy).sum()
        return loss_classification
--------------------------------------------------------------------------------
/dosed/utils/data_from_h5.py:
--------------------------------------------------------------------------------
"""Transform a folder with h5 files into a dataset for dosed"""

import numpy as np

import h5py

from ..preprocessing import normalizers
from scipy.interpolate import interp1d


def get_h5_data(filename, signals, fs):
    with h5py.File(filename, "r") as h5:

        signal_size = int(fs * min(
            set([h5[signal["h5_path"]].size / signal['fs'] for signal in signals])
        ))

        t_target = np.cumsum([1 / fs] * signal_size)
        data = np.zeros((len(signals), signal_size))
        for i, signal in enumerate(signals):
            t_source = np.cumsum([1 / signal["fs"]] *
                                 h5[signal["h5_path"]].size)
            normalizer = normalizers[signal['processing']["type"]](**signal['processing']['args'])
            data[i, :] = interp1d(t_source, normalizer(h5[signal["h5_path"]][:]),
                                  fill_value="extrapolate")(t_target)
    return data


def get_h5_events(filename, event, fs):
    with h5py.File(filename, "r") as h5:
        starts = h5[event["h5_path"]]["start"][:]
h5[event["h5_path"]]["start"][:] 32 | durations = h5[event["h5_path"]]["duration"][:] 33 | assert len(starts) == len(durations), "Inconsistents event durations and starts" 34 | 35 | data = np.zeros((2, len(starts))) 36 | data[0, :] = starts * fs 37 | data[1, :] = durations * fs 38 | return data 39 | -------------------------------------------------------------------------------- /dosed/preprocessing/normalizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def clip(max_value): 5 | """returns a function to clip data""" 6 | 7 | def clipper(signal_data, max_value=max_value): 8 | """returns input signal clipped between +/- max_value. 9 | """ 10 | return np.clip(signal_data, -max_value, max_value) 11 | 12 | return clipper 13 | 14 | 15 | def clip_and_normalize(min_value, max_value): 16 | """returns a function to clip and normalize data""" 17 | 18 | def clipper(x, min_value=min_value, max_value=max_value): 19 | """returns input signal clipped between min_value and max_value 20 | and then normalized between -0.5 and 0.5. 21 | """ 22 | x = np.clip(x, min_value, max_value) 23 | x = ((x - min_value) / 24 | (max_value - min_value)) - 0.5 25 | return x 26 | 27 | return clipper 28 | 29 | 30 | def mask_clip_and_normalize(min_value, max_value, mask_value): 31 | """returns a function to clip and normalize data""" 32 | 33 | def clipper(x, min_value=min_value, max_value=max_value, 34 | mask_value=mask_value): 35 | """returns input signal clipped between min_value and max_value 36 | and then normalized between -0.5 and 0.5. 37 | """ 38 | mask = np.ma.masked_equal(x, mask_value) 39 | x = np.clip(x, min_value, max_value) 40 | x = ((x - min_value) / 41 | (max_value - min_value)) - 0.5 42 | x[mask.mask] = mask_value 43 | return x 44 | 45 | return clipper 46 | -------------------------------------------------------------------------------- /minimum_example/to_h5.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | import h5py 6 | import pyedflib 7 | import tqdm 8 | 9 | print("\n Converting EDF and annotations to standard H5 file") 10 | download_directory = sys.argv[1] 11 | h5_directory = sys.argv[2] 12 | if not os.path.isdir(h5_directory): 13 | os.makedirs(h5_directory) 14 | 15 | records = [ 16 | x.split(".")[0] for x in os.listdir(download_directory) if x[-3:] == "edf" 17 | ] 18 | 19 | for record in tqdm.tqdm(records): 20 | edf_filename = download_directory + record + ".edf" 21 | spindle_filename = download_directory + record + "_spindle.json" 22 | h5_filename = '{}/{}.h5'.format(h5_directory, record) 23 | 24 | with h5py.File(h5_filename, 'w') as h5: 25 | 26 | # Taking care of spindle annotations 27 | spindles = [ 28 | (x["start"], x["end"] - x["start"]) for x in json.load(open(spindle_filename)) 29 | ] 30 | starts, durations = list(zip(*spindles)) 31 | h5.create_group("spindle") 32 | h5.create_dataset("spindle/start", data=starts) 33 | h5.create_dataset("spindle/duration", data=durations) 34 | 35 | # Extract signals 36 | with pyedflib.EdfReader(edf_filename) as f: 37 | labels = f.getSignalLabels() 38 | frequencies = f.getSampleFrequencies().astype(int).tolist() 39 | 40 | for i, (label, frequency) in enumerate(zip(labels, frequencies)): 41 | 42 | path = "{}".format(label.lower()) 43 | data = f.readSignal(i) 44 | h5.create_dataset(path, data=data) 45 | h5[path].attrs["fs"] = frequency 46 | 
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
settings.py
data/

.idea/*

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# mac .DS_Store files
.DS_Store

# pytest
.pytest_cache/
.coverage*
--------------------------------------------------------------------------------
/dosed/preprocessing/regularization.py:
--------------------------------------------------------------------------------
"""This script contains a set of transformations that can be applied to
input data before feeding it to the model"""

import numpy as np

import torch


class RegularizerBase:
    def __init__(self, p):
        self.p = p

    def __call__(self, x):
        if np.random.rand() < self.p:
            return self.call(x)
        else:
            return x

    def call(self, x):
        raise NotImplementedError


class GaussianNoise(RegularizerBase):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value you are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
31 | """ 32 | 33 | def __init__(self, sigma=0.01, p=1): 34 | super(GaussianNoise, self).__init__(p=p) 35 | self.sigma = sigma 36 | self.noise = torch.tensor(0) 37 | 38 | def call(self, x): 39 | if self.sigma != 0: 40 | scale = self.sigma * x 41 | sampled_noise = self.noise.repeat(*x.size()).float().normal_() * scale 42 | x = x + sampled_noise 43 | return x 44 | 45 | 46 | class RescaleNormal(RegularizerBase): 47 | def __init__(self, p=0.5, std=0.01): 48 | super(RescaleNormal, self).__init__(p=p) 49 | self.std = std 50 | 51 | def call(self, x): 52 | factor = np.random.normal(loc=1, scale=self.std) 53 | return x * factor 54 | 55 | 56 | class Invert(RegularizerBase): 57 | def __init__(self, p=0.5): 58 | super(Invert, self).__init__(p=p) 59 | 60 | def call(self, x): 61 | return x * -1 62 | -------------------------------------------------------------------------------- /dosed/functions/random_negative_mining_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from .simple_loss import DOSEDSimpleLoss 6 | 7 | 8 | class DOSEDRandomNegativeMiningLoss(DOSEDSimpleLoss): 9 | """Loss function inspired from https://github.com/amdegroot/ssd.pytorch""" 10 | 11 | def __init__(self, 12 | number_of_classes, 13 | device, 14 | factor_negative_mining=3, 15 | default_negative_mining=10, 16 | ): 17 | super(DOSEDRandomNegativeMiningLoss, self).__init__( 18 | number_of_classes=number_of_classes, 19 | device=device) 20 | self.factor_negative_mining = factor_negative_mining 21 | self.default_negative_mining = default_negative_mining 22 | 23 | def get_negative_index(self, positive, classifications, 24 | classifications_target): 25 | number_of_default_events = classifications.shape[1] 26 | number_of_positive = positive.long().sum(1) 27 | number_of_negative = torch.clamp( 28 | number_of_positive * self.factor_negative_mining, 29 | min=self.default_negative_mining) 30 | number_of_negative = torch.min( 31 | number_of_negative, (number_of_default_events - number_of_positive)) 32 | 33 | def pick_zero_random_index(tensor, size): 34 | result = torch.zeros_like(tensor) 35 | for index in np.random.choice( 36 | (1 - tensor).nonzero().view(-1), size=size, replace=False): 37 | result[index] = 1 38 | return result 39 | 40 | random_negative_index = [ 41 | pick_zero_random_index(line, int(number_of_negative[i])) 42 | for i, line in enumerate(torch.unbind(positive, dim=0))] 43 | negative = torch.stack(random_negative_index, dim=0) 44 | 45 | return negative 46 | -------------------------------------------------------------------------------- /dosed/functions/worst_negative_mining_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .simple_loss import DOSEDSimpleLoss 5 | 6 | 7 | class DOSEDWorstNegativeMiningLoss(DOSEDSimpleLoss): 8 | """Loss function inspired from https://github.com/amdegroot/ssd.pytorch""" 9 | 10 | def __init__(self, 11 | number_of_classes, 12 | device, 13 | factor_negative_mining=3, 14 | default_negative_mining=10, 15 | ): 16 | super(DOSEDWorstNegativeMiningLoss, self).__init__( 17 | number_of_classes=number_of_classes, 18 | device=device) 19 | self.factor_negative_mining = factor_negative_mining 20 | self.default_negative_mining = default_negative_mining 21 | 22 | def get_negative_index(self, positive, classifications, classifications_target): 23 | batch = classifications.shape[0] 24 | number_of_default_events = 
        number_of_positive = positive.long().sum(1)
        number_of_negative = torch.clamp(number_of_positive * self.factor_negative_mining,
                                         min=self.default_negative_mining)
        number_of_negative = torch.min(number_of_negative,
                                       (number_of_default_events - number_of_positive))
        loss_softmax = -torch.log(nn.Softmax(1)(
            classifications.view(-1, self.number_of_classes)).gather(
                1, classifications_target.view(-1, 1))).view(batch, -1)
        loss_softmax[positive] = 0
        _, loss_softmax_descending_index = loss_softmax.sort(1, descending=True)
        _, loss_softmax_descending_rank = loss_softmax_descending_index.sort(1)
        negative = (loss_softmax_descending_rank <
                    number_of_negative.unsqueeze(1).expand_as(loss_softmax_descending_rank))
        return negative
--------------------------------------------------------------------------------
/dosed/trainers/base_adam.py:
--------------------------------------------------------------------------------
"""Trainer class with Adam optimizer"""

from torch import device
import torch.optim as optim

from .base import TrainerBase


class TrainerBaseAdam(TrainerBase):
    """Trainer class with Adam optimizer"""

    def __init__(
            self,
            net,
            optimizer_parameters={
                "lr": 0.001,
                "weight_decay": 1e-8,
            },
            loss_specs={
                "type": "focal",
                "parameters": {
                    "number_of_classes": 1,
                    "alpha": 0.25,
                    "gamma": 2,
                    "device": device("cuda"),
                }
            },
            metrics=["precision", "recall", "f1"],
            epochs=100,
            metric_to_maximize="f1",
            patience=None,
            save_folder=None,
            logger_parameters={
                "num_events": 1,
                "output_dir": None,
                "output_fname": 'train_history.json',
                "metrics": ["precision", "recall", "f1"],
                "name_events": ["event_type_1"]
            },
            threshold_space={
                "upper_bound": 0.85,
                "lower_bound": 0.55,
                "num_samples": 10,
                "zoom_in": False,
            },
            matching_overlap=0.5,
    ):
        super(TrainerBaseAdam, self).__init__(
            net=net,
            optimizer_parameters=optimizer_parameters,
            loss_specs=loss_specs,
            metrics=metrics,
            epochs=epochs,
            metric_to_maximize=metric_to_maximize,
            patience=patience,
            save_folder=save_folder,
            logger_parameters=logger_parameters,
            threshold_space=threshold_space,
            matching_overlap=matching_overlap,
        )
        self.optimizer = optim.Adam(net.parameters(), **optimizer_parameters)
--------------------------------------------------------------------------------
/dosed/functions/detection.py:
--------------------------------------------------------------------------------
"""inspired from https://github.com/amdegroot/ssd.pytorch"""
import torch.nn as nn

from ..utils import non_maximum_suppression, decode


class Detection(nn.Module):
    """Turn raw network outputs into a list of detected events per sample."""

    def __init__(self,
                 number_of_classes,
                 overlap_non_maximum_suppression,
                 classification_threshold,
                 ):
        super(Detection, self).__init__()
        self.number_of_classes = number_of_classes
        self.overlap_non_maximum_suppression = overlap_non_maximum_suppression
        self.classification_threshold = classification_threshold

    def forward(self, localizations, classifications, localizations_default):
        batch = localizations.size(0)
        scores = nn.Softmax(dim=2)(classifications)
        results = []
        for i in range(batch):
            result = []
            localization_decoded = decode(localizations[i], localizations_default)
            for class_index in range(1, self.number_of_classes):  # we remove class 0 (background)
                scores_batch_class = scores[i, :, class_index]
                scores_batch_class_selected = scores_batch_class[
                    scores_batch_class > self.classification_threshold]
                if len(scores_batch_class_selected) == 0:
                    continue
                localizations_decoded_selected = localization_decoded[
                    (scores_batch_class > self.classification_threshold)
                    .unsqueeze(1).expand_as(localization_decoded)].view(-1, 2)

                events = non_maximum_suppression(
                    localizations_decoded_selected,
                    scores_batch_class_selected,
                    overlap=self.overlap_non_maximum_suppression,
                )
                result.extend([(event[0].item(), event[1].item(), class_index - 1)
                               for event in events])
            result = [event for event in result if event[0] > -10 and event[1] < 10]
            results.append(result)
        return results
--------------------------------------------------------------------------------
/dosed/functions/compute_metrics_dataset.py:
--------------------------------------------------------------------------------
import numpy as np

from .metrics import precision_function, recall_function, f1_function

available_score_functions = {
    "precision": precision_function(),
    "recall": recall_function(),
    "f1": f1_function(),
}


def compute_metrics_dataset(
        network,
        test_dataset,
        threshold,
        test_metrics=["precision", "recall", "f1"],
):
    """
    Computes metrics on the current net for test_dataset, using threshold
    as the classification threshold.
    """

    metrics = {
        score: score_function for score, score_function in
        available_score_functions.items() if score in test_metrics
    }

    metrics_test = [{
        metric: []
        for metric in metrics.keys()
    } for _ in range(network.number_of_classes - 1)]

    all_predicted_events = network.predict_dataset(
        test_dataset,
        threshold,
        batch_size=128)

    found_some_events = False

    for event_num in range(network.number_of_classes - 1):

        for record in test_dataset.records:

            # Select current event predictions
            predicted_events = all_predicted_events[record][event_num]

            # If there are no predictions, skip the record
            if len(predicted_events) == 0:
                continue

            found_some_events = True

            # Select current true events
            events = test_dataset.get_record_events(record)[event_num]

            # Compute metrics(predicted_events, events) for this record
            for metric in metrics.keys():
                metrics_test[event_num][metric].append(metrics[metric](
                    predicted_events,
                    events))

    # If the network predicted no event at all, for any event type or record, return -1
    if found_some_events is False:
        return -1

    for event_num in range(network.number_of_classes - 1):
        for metric in metrics.keys():
            metrics_test[event_num][metric] = np.nanmean(np.array(metrics_test[event_num][metric]))

    return metrics_test
--------------------------------------------------------------------------------
/tests/test_trainers.py:
--------------------------------------------------------------------------------
import torch

from dosed.datasets import BalancedEventDataset
from dosed.models import DOSED3
from dosed.trainers import trainers


def test_full_training():
    h5_directory = "./tests/test_files/h5/"

    window = 1  # in seconds
    signals = [
        {
            'h5_path': '/eeg_0',
            'fs': 64,
            'processing': {
                "type": "clip_and_normalize",
                "args": {
                    "min_value": -150,
                    "max_value": 150,
                }
            }
        },
        {
            'h5_path': '/eeg_1',
            'fs': 64,
            'processing': {
                "type": "clip_and_normalize",
                "args": {
                    "min_value": -150,
                    "max_value": 150,
                }
            }
        }
    ]

    events = [
        {
            "name": "spindle",
            "h5_path": "spindle",
        },
    ]

    device = torch.device("cuda")

    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=window,
        fs=64,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=0.5,
        n_jobs=-1,
    )

    # default events
    default_event_sizes = [1 * dataset.fs, 0.5 * dataset.fs]

    net = DOSED3(
        input_shape=dataset.input_shape,
        number_of_classes=dataset.number_of_classes,
        detection_parameters={
            "overlap_non_maximum_suppression": 0.5,
            "classification_threshold": 0.5,
        },
        default_event_sizes=default_event_sizes,
    )

    optimizer_parameters = {
        "lr": 5e-3,
        "weight_decay": 1e-8,
    }
    loss_specs = {
        "type": "worst_negative_mining",
        "parameters": {
            "number_of_classes": dataset.number_of_classes,
            "device": device,
        }
    }

    trainer = trainers["adam"](
        net,
        optimizer_parameters=optimizer_parameters,
        loss_specs=loss_specs,
        epochs=2,
    )

    best_net_train, best_metrics_train, best_threshold_train = trainer.train(
        dataset,
        dataset,
        batch_size=12,
    )

    best_net_train.predict_dataset(
        dataset,
        best_threshold_train,
        batch_size=2
    )
--------------------------------------------------------------------------------
/dosed/utils/match_events_to_default_localizations.py:
--------------------------------------------------------------------------------
import torch

from .jaccard_overlap import jaccard_overlap
from .encode import encode


def match_events_localization_to_default_localizations(localizations_default, events, threshold_overlap):
    batch = len(events)

    # Find localizations_target and classifications_target by matching
    # ground truth localizations to default localizations
    number_of_default_events = localizations_default.size(0)
    localizations_target = torch.Tensor(batch, number_of_default_events, 2)
    classifications_target = torch.LongTensor(batch, number_of_default_events)

    for batch_index in range(batch):

        # If there is no event, add a default value to predict (it will never be
        # used anyway) and set everything to class 0 == background
        if events[batch_index].numel() == 0:
            localizations_target[batch_index][:, :] = torch.FloatTensor(
                [[-1, 1]]).expand_as(localizations_default)
            classifications_target[batch_index] = torch.zeros(localizations_default.size(0))
            continue

        # Else match to the most overlapping event and set to background depending on threshold
        localizations_truth = events[batch_index][:, :2]
        classifications_truth = events[batch_index][:, -1]
        localizations_a = localizations_truth
        localizations_b = torch.cat(
            [(localizations_default[:, 0] - localizations_default[:, 1] / 2).unsqueeze(1),
             (localizations_default[:, 0] + localizations_default[:, 1] / 2).unsqueeze(1)],
            1
        )
        overlaps = jaccard_overlap(localizations_a, localizations_b)
        # (Bipartite Matching) https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py
        # might be useful if an event is included in another
        _, best_prior_index = overlaps.max(1, keepdim=True)
        best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
        best_truth_idx.squeeze_(0)
        best_truth_overlap.squeeze_(0)
        best_prior_index.squeeze_(1)
        # ensure every gt matches with its prior of max overlap
        best_truth_overlap.index_fill_(dim=0, index=best_prior_index, value=2)
        for j in range(best_prior_index.size(0)):
            best_truth_idx[best_prior_index[j]] = j

        localization_match = localizations_truth[best_truth_idx]
        localization_target = encode(localization_match, localizations_default)
        classification_target = classifications_truth[best_truth_idx] + 1  # Add class 0!
        classification_target[best_truth_overlap < threshold_overlap] = 0

        localizations_target[batch_index][:, :] = localization_target
        classifications_target[batch_index] = classification_target.long()

    return localizations_target, classifications_target
--------------------------------------------------------------------------------
/dosed/functions/metrics.py:
--------------------------------------------------------------------------------
import torch

from ..utils import jaccard_overlap


def precision_function():
    """returns a function to calculate precision"""

    def calculate_precision(prediction, reference, min_iou=0.3):
        """takes two event scorings
        (in array format [[start1, end1], [start2, end2], ...])
        and outputs the precision.

        Parameters
        ----------
        min_iou : float
            minimum intersection-over-union with a true event to be considered
            a true positive.
        """

        # Compute precision
        iou = jaccard_overlap(torch.Tensor(prediction),
                              torch.Tensor(reference))
        max_iou, _ = iou.max(1)
        true_positive = (max_iou >= min_iou).sum().item()
        false_positive = len(prediction) - true_positive
        precision = true_positive / (true_positive + false_positive)

        return precision

    return calculate_precision


def recall_function():
    """returns a function to calculate recall"""

    def calculate_recall(prediction, reference, min_iou=0.3):
        """takes two event scorings
        (in array format [[start1, end1], [start2, end2], ...])
        and outputs the recall.

        Parameters
        ----------
        min_iou : float
            minimum intersection-over-union with a true event to be considered
            a true positive.
        """

        # Compute recall
        iou = jaccard_overlap(torch.Tensor(prediction),
                              torch.Tensor(reference))
        max_iou, _ = iou.max(1)
        true_positive = (max_iou >= min_iou).sum().item()
        false_negative = len(reference) - true_positive
        recall = true_positive / (true_positive + false_negative)

        return recall

    return calculate_recall


def f1_function():
    """returns a function to calculate f1 score"""

    def calculate_f1_score(prediction, reference, min_iou=0.3):
        """takes two event scorings
        (in array format [[start1, end1], [start2, end2], ...])
        and outputs the f1 score.

        Parameters
        ----------
        min_iou : float
            minimum intersection-over-union with a true event to be considered
            a true positive.
75 | """ 76 | # Compute precision, recall, f1_score 77 | iou = jaccard_overlap(torch.Tensor(prediction), 78 | torch.Tensor(reference)) 79 | max_iou, _ = iou.max(1) 80 | true_positive = (max_iou >= min_iou).sum().item() 81 | false_positive = len(prediction) - true_positive 82 | false_negative = len(reference) - true_positive 83 | precision = true_positive / (true_positive + false_positive) 84 | recall = true_positive / (true_positive + false_negative) 85 | if precision == 0 or recall == 0: 86 | f1_score = 0 87 | else: 88 | f1_score = 2 * precision * recall / (precision + recall) 89 | 90 | return f1_score 91 | 92 | return calculate_f1_score 93 | -------------------------------------------------------------------------------- /dosed/models/dosed2.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | 5 | from ..functions import Detection 6 | from .base import BaseNet, get_overlerapping_default_events 7 | 8 | 9 | class DOSED2(BaseNet): 10 | 11 | def __init__( 12 | self, 13 | input_shape, 14 | number_of_classes, 15 | detection_parameters, 16 | default_event_sizes, 17 | k_max=8, 18 | ): 19 | super(DOSED2, self).__init__() 20 | self.number_of_channels, self.window_size = input_shape 21 | self.number_of_classes = number_of_classes + 1 # eventness, real events 22 | 23 | detection_parameters["number_of_classes"] = self.number_of_classes 24 | self.detector = Detection(**detection_parameters) 25 | 26 | self.k_max = 8 27 | 28 | # Localizations to default tensor 29 | self.localizations_default = get_overlerapping_default_events( 30 | window_size=self.window_size, 31 | default_event_sizes=default_event_sizes 32 | ) 33 | 34 | # model 35 | self.spatial_filtering = None 36 | if self.number_of_channels > 1: 37 | self.spatial_filtering = nn.Conv2d( 38 | in_channels=1, 39 | out_channels=self.number_of_channels, 40 | kernel_size=(self.number_of_channels, 1), 41 | padding=0) 42 | 43 | self.blocks = nn.ModuleList( 44 | [ 45 | nn.Sequential( 46 | OrderedDict([ 47 | ("conv_{}".format(k), nn.Conv2d( 48 | in_channels=4 * (2 ** (k - 1)) if k > 1 else 1, 49 | out_channels=4 * (2 ** k), 50 | kernel_size=(1, 3), 51 | padding=(0, 1) 52 | )), 53 | ("batchnorm_{}".format(k), nn.BatchNorm2d(4 * (2 ** k))), 54 | ("relu_{}".format(k), nn.ReLU()), 55 | ("max_pooling_{}".format(k), nn.MaxPool2d(kernel_size=(1, 2))), 56 | ]) 57 | ) for k in range(1, self.k_max + 1) 58 | ] 59 | ) 60 | self.localizations = nn.Conv2d( 61 | in_channels=4 * (2 ** self.k_max), 62 | out_channels=2 * len(self.localizations_default), 63 | kernel_size=(self.number_of_channels, int(self.window_size / (2 ** (self.k_max)))), 64 | padding=(0, 0), 65 | ) 66 | 67 | self.classifications = nn.Conv2d( 68 | in_channels=4 * (2 ** self.k_max), 69 | out_channels=self.number_of_classes * len(self.localizations_default), 70 | kernel_size=(self.number_of_channels, int(self.window_size / (2 ** (self.k_max)))), 71 | padding=(0, 0), 72 | ) 73 | 74 | def forward(self, x): 75 | batch = x.size(0) 76 | x = x.view(batch, 1, self.number_of_channels, -1) 77 | 78 | if self.spatial_filtering: 79 | x = self.spatial_filtering(x) 80 | x = x.transpose(2, 1) 81 | 82 | for block in self.blocks: 83 | x = block(x) 84 | 85 | localizations = self.localizations(x).squeeze().view(batch, -1, 2) 86 | classifications = self.classifications(x).squeeze().view(batch, -1, self.number_of_classes) 87 | 88 | return localizations, classifications, self.localizations_default 89 | 
/dosed/functions/simple_loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class DOSEDSimpleLoss(nn.Module):
    """Loss function inspired from https://github.com/amdegroot/ssd.pytorch"""

    def __init__(self,
                 number_of_classes,
                 device,
                 ):
        super(DOSEDSimpleLoss, self).__init__()
        self.device = device
        self.number_of_classes = number_of_classes + 1  # eventlessness

    def localization_loss(self, positive, localizations, localizations_target):
        # Localization Loss (Smooth L1)
        positive_expanded = positive.unsqueeze(positive.dim()).expand_as(
            localizations)
        loss_localization = F.smooth_l1_loss(
            localizations[positive_expanded].view(-1, 2),
            localizations_target[positive_expanded].view(-1, 2),
            reduction="sum")
        return loss_localization

    def get_negative_index(self, positive, classifications,
                           classifications_target):
        negative = (classifications_target == 0)
        return negative

    def get_classification_loss(self, index, classifications,
                                classifications_target):
        index_expanded = index.unsqueeze(2).expand_as(classifications)

        loss_classification = F.cross_entropy(
            classifications[index_expanded.gt(0)
                            ].view(-1, self.number_of_classes),
            classifications_target[index.gt(0)],
            reduction="sum",
        )
        return loss_classification

    def forward(self, localizations, classifications, localizations_target,
                classifications_target):

        positive = classifications_target > 0
        negative = self.get_negative_index(positive, classifications,
                                           classifications_target)

        number_of_positive_all = positive.long().sum().float()
        number_of_negative_all = negative.long().sum().float()

        # loc loss
        loss_localization = self.localization_loss(positive, localizations,
                                                   localizations_target)

        # + Classification loss
        loss_classification_positive = 0
        if number_of_positive_all > 0:
            loss_classification_positive = self.get_classification_loss(
                positive, classifications, classifications_target)

        # - Classification loss
        loss_classification_negative = 0
        if number_of_negative_all > 0:
            loss_classification_negative = self.get_classification_loss(
                negative, classifications, classifications_target)

        # Loss: sum
        loss_classification_positive_normalized = (
            loss_classification_positive /
            number_of_positive_all)
        loss_classification_negative_normalized = (
            loss_classification_negative /
            number_of_negative_all)
        loss_localization_normalized = loss_localization / number_of_positive_all

        return (loss_classification_positive_normalized,
                loss_classification_negative_normalized,
                loss_localization_normalized)
--------------------------------------------------------------------------------
/dosed/models/dosed1.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn

from ..functions import Detection
from .base import BaseNet


class DOSED1(BaseNet):

    def __init__(
            self,
            input_shape,
            number_of_classes,
            detection_parameters,
            duration=256,
            k_max=8,
            rho=2
    ):
        super(DOSED1, self).__init__()
        self.number_of_channels, self.window_size = input_shape
        self.number_of_classes = number_of_classes + 1  # + 1 for the background class

        detection_parameters["number_of_classes"] = self.number_of_classes
        self.detector = Detection(**detection_parameters)

        self.localizations_default = []

        self.rho = rho
        self.k_max = k_max
        self.duration = duration
        self.number_of_default_events = self.window_size * self.rho / (2 ** self.k_max)
        assert self.number_of_default_events % 1 == 0

        # model
        self.spatial_filtering = None
        if self.number_of_channels > 1:
            self.spatial_filtering = nn.Conv2d(
                in_channels=1,
                out_channels=self.number_of_channels,
                kernel_size=(self.number_of_channels, 1),
                padding=0)

        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    OrderedDict([
                        ("conv_{}".format(k), nn.Conv2d(
                            in_channels=4 * (2 ** (k - 1)) if k > 1 else 1,
                            out_channels=4 * (2 ** k),
                            kernel_size=(1, 3),
                            padding=(0, 1)
                        )),
                        ("batchnorm_{}".format(k), nn.BatchNorm2d(4 * (2 ** k))),
                        ("relu_{}".format(k), nn.ReLU()),
                        ("max_pooling_{}".format(k), nn.MaxPool2d(kernel_size=(1, 2))),
                    ])
                ) for k in range(1, self.k_max + 1)
            ]
        )

        self.localizations = nn.Conv2d(
            in_channels=4 * (2 ** self.k_max),
            out_channels=2 * self.rho,
            kernel_size=(self.number_of_channels, 3),
            padding=(0, 1),
        )

        self.classifications = nn.Conv2d(
            in_channels=4 * (2 ** self.k_max),
            out_channels=self.number_of_classes * self.rho,
            kernel_size=(self.number_of_channels, 3),
            padding=(0, 1),
        )

        # Localizations to default tensor
        self.localizations_default = torch.Tensor([
            [((2 ** self.k_max)) * (0.5 + i) / (self.rho * self.window_size),
             self.duration / self.window_size] for i in range(int(self.number_of_default_events))
        ]
        )

    def forward(self, x):
        batch = x.size(0)
        x = x.view(batch, 1, self.number_of_channels, -1)

        if self.spatial_filtering:
            x = self.spatial_filtering(x)
            x = x.transpose(2, 1)

        for block in self.blocks:
            x = block(x)
        localizations = self.localizations(x).view(
            batch,
            2 * self.rho,
            int(self.window_size / (2 ** self.k_max))
        ).permute(0, 2, 1).contiguous().view(batch, -1, 2)
        classifications = self.classifications(x).view(
            batch,
            self.number_of_classes * self.rho,
            int(self.window_size / (2 ** self.k_max))
        ).permute(0, 2, 1).contiguous().view(batch, -1, self.number_of_classes)

        return localizations, classifications, self.localizations_default
--------------------------------------------------------------------------------
/dosed/utils/logger.py:
--------------------------------------------------------------------------------
"""
Some simple logging functionality.

Logs training history to a JSON file (train_history.json by default).
"""

import os
import json
import time
import os.path as osp
from .colorize import colorize
import tempfile


class Logger:

    """
    A logger to track training parameters.
21 | """ 22 | 23 | def __init__(self, 24 | num_events, 25 | output_dir=None, 26 | output_fname='train_history.json', 27 | metrics=["precision", "recall", "f1"], 28 | name_events=["event_type_1", "event_type_2"], 29 | ): 30 | """ 31 | Initialize a Logger. 32 | """ 33 | 34 | assert len(name_events) == num_events 35 | self.name_events = name_events 36 | self.metrics = metrics 37 | self.output_fname = output_fname 38 | self.output_dir = output_dir if output_dir is not None else tempfile.mkdtemp() 39 | if not os.path.isdir(self.output_dir): 40 | os.mkdir(self.output_dir) 41 | self.output_file = osp.join(self.output_dir, output_fname) 42 | print(colorize("Logging data to %s" % self.output_file, 'green', 43 | bold=True)) 44 | 45 | self.num_events = num_events 46 | self.history_time = [] 47 | self.history_loc_loss = {"train": [], "validation": []} 48 | self.history_class_pos_loss = {"train": [], "validation": []} 49 | self.history_class_neg_loss = {"train": [], "validation": []} 50 | self.history_metrics = [] 51 | self.current_epoch_metrics = { 52 | name_event: {metric: [] for metric in self.metrics} 53 | for name_event in self.name_events 54 | } 55 | 56 | def log_msg(self, msg, color='green'): 57 | """ Print a colorized message to stdout. """ 58 | print(colorize(msg, color, bold=True)) 59 | 60 | def add_new_loss(self, loc_loss, class_pos_loss, class_neg_loss, 61 | mode="validation"): 62 | """ Adds loss values of a new epoch. Call one time per epoch """ 63 | self.history_loc_loss[mode].append(loc_loss) 64 | self.history_class_pos_loss[mode].append(class_pos_loss) 65 | self.history_class_neg_loss[mode].append(class_neg_loss) 66 | 67 | def add_new_metrics(self, metrics): 68 | """ 69 | Adds metric values to the current epoch metrics. 70 | Call as many times per epoch as required. 71 | """ 72 | assert len(metrics[0]) == self.num_events 73 | for num_event, event in enumerate(self.name_events): 74 | for metric in self.metrics: 75 | self.current_epoch_metrics[event][metric].append( 76 | (metrics[0][num_event][metric], metrics[1]) 77 | ) 78 | 79 | def add_current_metrics_to_history(self): 80 | """ 81 | Adds current_epoch_metrics to history and resets the variable. 
82 |         Call at the end of each epoch
83 |         """
84 |         self.history_metrics.append(self.current_epoch_metrics)
85 |         self.history_time.append(time.time())
86 |         self.current_epoch_metrics = {
87 |             name_event: {metric: [] for metric in self.metrics}
88 |             for name_event in self.name_events
89 |         }
90 | 
91 |     def dump_train_history(self):
92 |         """ Dump training history into a .json file """
93 | 
94 |         if len(self.history_loc_loss["train"]) != len(
95 |                 self.history_class_pos_loss["train"]) or len(
96 |                 self.history_class_pos_loss["train"]) != len(
97 |                 self.history_class_neg_loss["train"]) or len(
98 |                 self.history_class_neg_loss["train"]) != len(self.history_metrics):
99 |             print(colorize('Warning: lengths of losses and metrics are not consistent',
100 |                            'red'))
101 | 
102 |         train_history = {}
103 |         train_history["loc_loss"] = self.history_loc_loss
104 |         train_history["class_pos_loss"] = self.history_class_pos_loss
105 |         train_history["class_neg_loss"] = self.history_class_neg_loss
106 |         train_history["metrics"] = self.history_metrics
107 |         train_history["time_stamps"] = self.history_time
108 |         json.dump(train_history,
109 |                   open(self.output_file, 'w'),
110 |                   indent=4)
111 | --------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![CircleCI](https://circleci.com/gh/Dreem-Organization/dosed.svg?style=svg&circle-token=7b6f5fd8d3db49d25417b269c601296b7eebd64f)](https://circleci.com/gh/Dreem-Organization/dosed)
2 | 
3 | ## Dreem One Shot Event Detector (DOSED)
4 | 
5 | This repository contains a functional implementation of DOSED, a deep learning method first proposed in:
6 | 
7 | Stanislas Chambon, Valentin Thorey, Pierrick J. Arnal, Emmanuel Mignot, Alexandre Gramfort
8 | A deep learning architecture to detect events in EEG signals during sleep
9 | IEEE 28th International Workshop on Machine Learning for Signal Processing (MLSP), 2018
10 | https://arxiv.org/abs/1807.05981
11 | 
12 | and extended in:
13 | 
14 | Stanislas Chambon, Valentin Thorey, Pierrick J. Arnal, Emmanuel Mignot, Alexandre Gramfort.
15 | DOSED: a deep learning approach to detect multiple sleep micro-events in EEG signal
16 | https://arxiv.org/abs/1812.04079
17 | 
18 | also used in:
19 | 
20 | Valentin Thorey and Albert Bou Hernandez and Pierrick J. Arnal and Emmanuel H. During.
21 | AI vs Humans for the diagnosis of sleep apnea
22 | https://arxiv.org/abs/1906.09936
23 | 
24 | ### Introduction
25 | 
26 | DOSED is a deep learning approach that jointly predicts the locations, durations and types of events in time series.
27 | It was inspired by computer vision object detectors such as YOLO and SSD and relies on a convolutional neural network that builds a feature representation from raw input signals,
28 | as well as two modules performing localization and classification respectively. DOSED can easily be adapted to detect events of any sort.
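At inference time the detector returns, for each record and each event class, a list of detected `[start, stop]` positions. A toy illustration of the output format of `predict_dataset` (the record name and values here are made up):

```python
# Hypothetical output of net.predict_dataset(dataset_test, threshold=0.7):
# one entry per record, one list per event class, [start, stop] in sample indices.
predictions = {
    "record_01.h5": [
        [[1280, 1344], [5120, 5210]],  # e.g. detected spindles
    ],
}
```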
29 | 
30 | ![dosed_detection_image](https://github.com/Dreem-Organization/dosed/blob/master/dosed_detection.png)
31 | 
32 | ### Citing DOSED
33 | 
34 |     @inproceedings{chambon2018deep,
35 |       title={A deep learning architecture to detect events in EEG signals during sleep},
36 |       author={Chambon, Stanislas and Thorey, Valentin and Arnal, Pierrick J and Mignot, Emmanuel and Gramfort, Alexandre},
37 |       booktitle={2018 IEEE 28th International Workshop on Machine Learning for Signal Processing (MLSP)},
38 |       pages={1--6},
39 |       year={2018},
40 |       organization={IEEE}
41 |     }
42 | 
43 |     @article{chambon2018dosed,
44 |       title={DOSED: a deep learning approach to detect multiple sleep micro-events in EEG signal},
45 |       author={Chambon, Stanislas and Thorey, Valentin and Arnal, Pierrick J and Mignot, Emmanuel and Gramfort, Alexandre},
46 |       journal={arXiv preprint arXiv:1812.04079},
47 |       year={2018}
48 |     }
49 | 
50 | ### Minimum example
51 | 
52 | The */minimum_example* folder contains all the code necessary to train a spindle detection model on EEG signals.
53 | 
54 | We provide a dataset composed of 21 recordings with two central EEG channels downsampled to 64 Hz, on which spindles have been annotated. The data was collected at [Dreem](http://www.dreem.com) with a polysomnography device.
55 | 
56 | The example works out-of-the-box given the following considerations.
57 | 
58 | #### 1. Package requirements
59 | 
60 | The packages detailed in *requirements.txt* need to be installed for the example to work.
61 | 
62 | 
63 | #### 2. Minimum example
64 | 
65 | A minimum example is provided in the */minimum\_example* directory.
66 | 
67 | Run the ipython notebook *download_and_data_format_explanation.ipynb*, or `make download_example`, to download and pre-process the training data.
68 | 
69 | ##### H5 data format
70 | 
71 | To work with different datasets, and hence data formats, we first require you to convert your original
72 | data and annotations into one H5 file per record. *download_and_data_format_explanation.ipynb* and *to_h5.py* provide a detailed explanation and an example of that process.
73 | 
74 | The required structure for the .h5 files is the following:
75 | 
76 | ```
77 | / # root of the h5 file
78 | 
79 | -> /path/to/signal_1
80 |    + attribute "fs" # sampling frequency
81 | 
82 | -> /path/to/signal_2
83 |    + attribute "fs"
84 | 
85 | -> /path/to/signal_3
86 |    + attribute "fs"
87 | 
88 | -> ... # add as many signals as desired
89 | 
90 | 
91 | -> /path/to/event_1/
92 |    -> /start # array containing the start position of each event with respect to the beginning of the recording (in seconds).
93 |    -> /duration # array containing the duration of each event (in seconds).
94 | 
95 | -> /path/to/event_2/
96 |    -> /start
97 |    -> /duration
98 | 
99 | -> ... # add as many events as desired
100 | ```
101 | 
102 | This conversion is the only dataset-specific code that you will need to write.
103 | 
104 | #### Training and testing
105 | 
106 | The jupyter notebook *train\_and\_evaluate\_dosed.ipynb* goes through the training process in detail, describing all important training parameters. It also explains how to generate predictions, and provides a plot of a spindle detection.
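For orientation, here is a condensed sketch of that pipeline, using only the signatures visible in this repository. The data path, split percentages and hyperparameters below are illustrative, not prescriptive; the notebook uses its own values and is the reference.

```python
import torch

from dosed.datasets import BalancedEventDataset, get_train_validation_test
from dosed.models import DOSED3
from dosed.trainers import trainers

h5_directory = "./data/h5"  # adapt to the DOWNLOAD_PATH you used
train, validation, test = get_train_validation_test(
    h5_directory, percent_test=25, percent_validation=33, seed=2019)

signals = [{"h5_path": "/eeg_0", "fs": 64,
            "processing": {"type": "clip_and_normalize",
                           "args": {"min_value": -150, "max_value": 150}}}]
events = [{"name": "spindle", "h5_path": "spindle"}]

dataset_parameters = dict(h5_directory=h5_directory, signals=signals,
                          events=events, window=10, fs=32,
                          ratio_positive=0.5, n_jobs=-1, cache_data=True)
dataset_train = BalancedEventDataset(records=train, **dataset_parameters)
dataset_validation = BalancedEventDataset(records=validation, **dataset_parameters)
dataset_test = BalancedEventDataset(records=test, **dataset_parameters)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = DOSED3(
    input_shape=dataset_train.input_shape,
    number_of_classes=dataset_train.number_of_classes,
    detection_parameters={"overlap_non_maximum_suppression": 0.5,
                          "classification_threshold": 0.7},
    # a priori event sizes (in samples): spindles last around 1 second
    default_event_sizes=[size * dataset_train.fs for size in (0.7, 1, 1.3)],
).to(device)

trainer = trainers["basic"](  # plain SGD trainer; see dosed/trainers/base.py
    net=net,
    optimizer_parameters={"lr": 5e-3, "weight_decay": 1e-8},
    loss_specs={"type": "focal",
                "parameters": {"number_of_classes": dataset_train.number_of_classes,
                               "device": device}},
    epochs=20,
)
best_net, best_metrics, best_threshold = trainer.train(
    dataset_train, dataset_validation, batch_size=128)

# per-record, per-class [start, stop] detections on held-out data
predictions = best_net.predict_dataset(dataset_test, best_threshold)
```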
107 | -------------------------------------------------------------------------------- /dosed/models/dosed3.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | 4 | import torch.nn as nn 5 | 6 | from ..functions import Detection 7 | from .base import BaseNet, get_overlerapping_default_events 8 | 9 | 10 | class DOSED3(BaseNet): 11 | 12 | def __init__(self, 13 | input_shape, 14 | number_of_classes, 15 | detection_parameters, 16 | default_event_sizes, 17 | k_max=6, 18 | kernel_size=5, 19 | pdrop=0.1, 20 | fs=256): 21 | 22 | super(DOSED3, self).__init__() 23 | self.number_of_channels, self.window_size = input_shape 24 | self.number_of_classes = number_of_classes + 1 # eventless, real events 25 | 26 | detection_parameters["number_of_classes"] = self.number_of_classes 27 | self.detector = Detection(**detection_parameters) 28 | 29 | self.k_max = k_max 30 | self.kernel_size = kernel_size 31 | self.pdrop = pdrop 32 | 33 | if max(default_event_sizes) > self.window_size: 34 | warnings.warn("Detected default_event_sizes larger than" 35 | " input_shape! Consider reducing them") 36 | 37 | # Localizations to default tensor 38 | self.localizations_default = get_overlerapping_default_events( 39 | window_size=self.window_size, 40 | default_event_sizes=default_event_sizes 41 | ) 42 | 43 | # model 44 | self.blocks = nn.ModuleList( 45 | [ 46 | nn.Sequential( 47 | OrderedDict([ 48 | ("conv_{}".format(k - 1), nn.Conv1d( 49 | in_channels=4 * (2 ** (k - 1)) if k > 1 else self.number_of_channels, 50 | out_channels=4 * (2 ** k), 51 | kernel_size=self.kernel_size, 52 | padding=2 53 | )), 54 | ("batchnorm_{}".format(k - 1), nn.BatchNorm1d(4 * (2 ** k))), 55 | ("relu_{}".format(k), nn.ReLU()), 56 | ("dropput_{}".format(k), nn.Dropout(self.pdrop)), 57 | ("max_pooling_{}".format(k), nn.MaxPool1d(kernel_size=2)), 58 | ]) 59 | ) for k in range(1, self.k_max + 1) 60 | ] 61 | ) 62 | self.localizations = nn.Conv1d( 63 | in_channels=4 * (2 ** (self.k_max)), 64 | out_channels=2 * len(self.localizations_default), 65 | kernel_size=int(self.window_size / (2 ** (self.k_max))), 66 | padding=0, 67 | ) 68 | 69 | self.classifications = nn.Conv1d( 70 | in_channels=4 * (2 ** (self.k_max)), 71 | out_channels=self.number_of_classes * len(self.localizations_default), 72 | kernel_size=int(self.window_size / (2 ** (self.k_max))), 73 | padding=0, 74 | ) 75 | 76 | self.print_info_architecture(fs) 77 | 78 | def forward(self, x): 79 | batch = x.size(0) 80 | for block in self.blocks: 81 | x = block(x) 82 | localizations = self.localizations(x).squeeze().view(batch, -1, 2) 83 | classifications = self.classifications(x).squeeze().view(batch, -1, self.number_of_classes) 84 | 85 | return localizations, classifications, self.localizations_default 86 | 87 | def print_info_architecture(self, fs): 88 | 89 | size = self.window_size 90 | receptive_field = 0 91 | print("\nInput feature map size: {}".format(size)) 92 | print("Input receptive field: {}".format(receptive_field)) 93 | print("Input size in seconds: {} s".format(size / fs)) 94 | print("Input receptive field in seconds: {} s \n".format(receptive_field / fs)) 95 | 96 | kernal_size = self.kernel_size 97 | 98 | size //= 2 99 | receptive_field = kernal_size + 1 100 | print("After layer 1:") 101 | print("\tFeature map size: {}".format(size)) 102 | print("\tReceptive field: {}".format(receptive_field)) 103 | print("\tReceptive field in seconds: {} s".format(receptive_field / fs)) 104 | 105 | for layer in 
range(2, self.k_max + 1): 106 | size //= 2 107 | receptive_field += (kernal_size // 2) * 2 * 2 ** (layer - 1) # filter 108 | receptive_field += 2 ** (layer - 1) # max_pool 109 | print("After layer {}:".format(layer)) 110 | print("\tFeature map size: {}".format(size)) 111 | print("\tReceptive field: {}".format(receptive_field)) 112 | print("\tReceptive field in seconds: {} s".format( 113 | receptive_field / fs)) 114 | print("\n") 115 | -------------------------------------------------------------------------------- /dosed/models/base.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | import tempfile 3 | import json 4 | import shutil 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ..utils import binary_to_array 10 | 11 | 12 | class BaseNet(nn.Module): 13 | 14 | def __init__(self): 15 | super(BaseNet, self).__init__() 16 | 17 | @property 18 | def device(self): 19 | try: 20 | out = next(self.parameters()).device 21 | return (out if isinstance(out, torch.device) 22 | else torch.device('cpu')) 23 | except Exception: 24 | return torch.device('cpu') 25 | 26 | def predict(self, x): 27 | localizations, classifications, localizations_default = self.forward(x) 28 | localizations_default = localizations_default.to(self.device) 29 | return self.detector(localizations, classifications, localizations_default) 30 | 31 | def save(self, filename, net_parameters): 32 | with tarfile.open(filename, "w") as tar: 33 | temporary_directory = tempfile.mkdtemp() 34 | name = "{}/net_params.json".format(temporary_directory) 35 | json.dump(net_parameters, open(name, "w")) 36 | tar.add(name, arcname="net_params.json") 37 | name = "{}/state.torch".format(temporary_directory) 38 | torch.save(self.state_dict(), name) 39 | tar.add(name, arcname="state.torch") 40 | shutil.rmtree(temporary_directory) 41 | return filename 42 | 43 | @classmethod 44 | def load(cls, filename, use_device=torch.device('cpu')): 45 | with tarfile.open(filename, "r") as tar: 46 | net_parameters = json.loads( 47 | tar.extractfile("net_params.json").read().decode("utf-8")) 48 | path = tempfile.mkdtemp() 49 | tar.extract("state.torch", path=path) 50 | net = cls(**net_parameters) 51 | net.load_state_dict( 52 | torch.load( 53 | path + "/state.torch", 54 | map_location=use_device, 55 | ) 56 | ) 57 | return net, net_parameters 58 | 59 | def predict_dataset(self, 60 | inference_dataset, 61 | threshold, 62 | overlap_factor=0.5, 63 | batch_size=128, 64 | ): 65 | """ 66 | Predicts events in inference_dataset. 
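Returns a dict mapping each record name to one list per event class, each containing the [start, stop] sample indices of the detections obtained by thresholding classification scores over overlapping windows and merging the resulting binary masks.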
67 | """ 68 | 69 | # Set network to eval mode 70 | self.eval() 71 | 72 | # Set network prediction parameters 73 | self.detector.classification_threshold = threshold 74 | window_size = inference_dataset.window_size 75 | window = inference_dataset.window 76 | overlap = window * overlap_factor 77 | 78 | # List of dicts, to save predictions of each class per record 79 | predictions = {} 80 | for record in inference_dataset.records: 81 | predictions[record] = [] 82 | result = np.zeros((self.number_of_classes - 1, 83 | inference_dataset.signals[record]["size"])) 84 | for signals, times in inference_dataset.get_record_batch( 85 | record, 86 | batch_size=int(batch_size), 87 | stride=overlap): 88 | x = signals.to(self.device) 89 | batch_predictions = self.predict(x) 90 | 91 | for events, time in zip(batch_predictions, times): 92 | for event in events: 93 | start = int(round(event[0] * window_size + time[0])) 94 | stop = int(round(event[1] * window_size + time[0])) 95 | result[event[2], start:stop] = 1 96 | 97 | predicted_events = [binary_to_array(k) for k in result] 98 | assert len(predicted_events) == self.number_of_classes - 1 99 | for event_num in range(self.number_of_classes - 1): 100 | predictions[record].append(predicted_events[event_num]) 101 | 102 | return predictions 103 | 104 | @property 105 | def nelement(self): 106 | cpt = 0 107 | for p in self.parameters(): 108 | cpt += p.nelement() 109 | return cpt 110 | 111 | 112 | def get_overlerapping_default_events(window_size, default_event_sizes, factor_overlap=2): 113 | window_size = window_size 114 | default_event_sizes = default_event_sizes 115 | factor_overlap = factor_overlap 116 | default_events = [] 117 | for default_event_size in default_event_sizes: 118 | overlap = default_event_size / factor_overlap 119 | number_of_default_events = int(window_size / overlap) 120 | default_events.extend( 121 | [(overlap * (0.5 + i) / window_size, default_event_size / window_size) 122 | for i in range(number_of_default_events)] 123 | ) 124 | return torch.Tensor(default_events) 125 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | 5 | from dosed.models import DOSED1, DOSED2, DOSED3 6 | 7 | 8 | def test_dosed1(): 9 | batch_size = 32 10 | number_of_channels = 2 11 | window_duration = 10 12 | fs = 256 13 | x = torch.rand(batch_size, number_of_channels, window_duration * fs) 14 | 15 | # number of classes 16 | number_of_classes = 3 17 | 18 | # default events 19 | default_event_duration = 1 20 | overlap_default_event = 2 21 | 22 | net = DOSED1( 23 | input_shape=(number_of_channels, window_duration * fs), 24 | number_of_classes=number_of_classes, 25 | detection_parameters={ 26 | "overlap_non_maximum_suppression": 0.5, 27 | "classification_threshold": 0.5, 28 | }, 29 | duration=default_event_duration * fs, 30 | rho=overlap_default_event 31 | ) 32 | localizations, classifications, localizations_default = net.forward(x) 33 | number_of_default_events = int(window_duration / default_event_duration * overlap_default_event) 34 | assert localizations.shape == (batch_size, number_of_default_events, 2) 35 | assert classifications.shape == (batch_size, number_of_default_events, number_of_classes + 1) 36 | assert localizations_default.shape == (number_of_default_events, 2) 37 | 38 | 39 | def test_dosed2(): 40 | batch_size = 32 41 | number_of_channels = 2 42 | window_duration = 10 43 | fs = 256 
44 | x = torch.rand(batch_size, number_of_channels, window_duration * fs) 45 | 46 | # number of classes 47 | number_of_classes = 3 48 | 49 | # default events 50 | default_event_sizes = [1 * fs, 2 * fs] 51 | 52 | net = DOSED2( 53 | input_shape=(number_of_channels, window_duration * fs), 54 | number_of_classes=number_of_classes, 55 | detection_parameters={ 56 | "overlap_non_maximum_suppression": 0.5, 57 | "classification_threshold": 0.5, 58 | }, 59 | default_event_sizes=default_event_sizes, 60 | ) 61 | localizations, classifications, localizations_default = net.forward(x) 62 | number_of_default_events = sum([ 63 | int(window_duration * fs / default_event_size * 2) 64 | for default_event_size in default_event_sizes] 65 | ) 66 | assert localizations.shape == (batch_size, number_of_default_events, 2) 67 | assert classifications.shape == (batch_size, number_of_default_events, number_of_classes + 1) 68 | assert localizations_default.shape == (number_of_default_events, 2) 69 | 70 | 71 | def test_dosed3(): 72 | batch_size = 32 73 | number_of_channels = 2 74 | window_duration = 10 75 | fs = 256 76 | x = torch.rand(batch_size, number_of_channels, window_duration * fs) 77 | 78 | # number of classes 79 | number_of_classes = 3 80 | 81 | # default events 82 | default_event_sizes = [1 * fs, 2 * fs] 83 | 84 | net = DOSED3( 85 | input_shape=(number_of_channels, window_duration * fs), 86 | number_of_classes=number_of_classes, 87 | detection_parameters={ 88 | "overlap_non_maximum_suppression": 0.5, 89 | "classification_threshold": 0.5, 90 | }, 91 | default_event_sizes=default_event_sizes, 92 | ) 93 | localizations, classifications, localizations_default = net.forward(x) 94 | number_of_default_events = sum([ 95 | int(window_duration * fs / default_event_size * 2) 96 | for default_event_size in default_event_sizes] 97 | ) 98 | assert localizations.shape == (batch_size, number_of_default_events, 2) 99 | assert classifications.shape == (batch_size, number_of_default_events, number_of_classes + 1) 100 | assert localizations_default.shape == (number_of_default_events, 2) 101 | 102 | 103 | def test_save_load(): 104 | batch_size = 32 105 | number_of_channels = 2 106 | window_duration = 10 107 | fs = 64 108 | x = torch.rand(batch_size, number_of_channels, window_duration * fs) 109 | 110 | net_parameters = { 111 | "input_shape": [number_of_channels, window_duration * fs], 112 | "number_of_classes": 3, 113 | "detection_parameters": { 114 | "overlap_non_maximum_suppression": 0.5, 115 | "classification_threshold": 0.5, 116 | }, 117 | "default_event_sizes": [64], 118 | } 119 | net = DOSED3( 120 | **net_parameters 121 | ) 122 | filename = tempfile.mkdtemp() + "/lol.lol" 123 | net.save(filename, net_parameters) 124 | 125 | net_loaded, net_parameters_loaded = net.load(filename) 126 | 127 | assert net_parameters_loaded == net_parameters 128 | 129 | net.eval() 130 | localizations, classifications, localizations_default = net.forward(x) 131 | net_loaded.eval() 132 | localizations_, classifications_, localizations_default_ = net_loaded.forward(x) 133 | 134 | assert localizations.tolist() == localizations_.tolist() 135 | assert classifications.tolist() == classifications_.tolist() 136 | assert localizations_default.tolist() == localizations_default_.tolist() 137 | 138 | 139 | def test_nelement(): 140 | net_parameters = { 141 | "input_shape": [1, 20], 142 | "number_of_classes": 3, 143 | "detection_parameters": { 144 | "overlap_non_maximum_suppression": 0.5, 145 | "classification_threshold": 0.5, 146 | }, 147 | 
"default_event_sizes": [10], 148 | "k_max": 1 149 | } 150 | net = DOSED3( 151 | **net_parameters 152 | ) 153 | assert net.nelement == 2008 154 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import pytest 3 | import time 4 | import os 5 | from unittest.mock import patch 6 | 7 | from dosed.datasets import BalancedEventDataset, EventDataset, get_train_validation_test 8 | 9 | 10 | @pytest.fixture 11 | def h5_directory(): 12 | return "./tests/test_files/h5/" 13 | 14 | 15 | @pytest.fixture 16 | def records(h5_directory): 17 | train, validation, test = get_train_validation_test(h5_directory, 50, 50, seed=2008) 18 | return train + validation + test 19 | 20 | 21 | @pytest.fixture 22 | def signals(): 23 | return [ 24 | { 25 | 'h5_path': '/eeg_0', 26 | 'fs': 64, 27 | 'processing': { 28 | "type": "clip_and_normalize", 29 | "args": { 30 | "min_value": -150, 31 | "max_value": 150, 32 | } 33 | } 34 | }, 35 | { 36 | 'h5_path': '/eeg_1', 37 | 'fs': 64, 38 | 'processing': { 39 | "type": "clip_and_normalize", 40 | "args": { 41 | "min_value": -150, 42 | "max_value": 150, 43 | } 44 | } 45 | } 46 | ] 47 | 48 | 49 | @pytest.fixture 50 | def events(): 51 | return [ 52 | { 53 | "name": "spindle", 54 | "h5_path": "spindle", 55 | } 56 | ] 57 | 58 | 59 | @pytest.fixture 60 | def cache_directory(): 61 | return "./tests/test_files/h5/.cache" 62 | 63 | 64 | def test_dataset(signals, events, h5_directory, records): 65 | 66 | window = 2 67 | 68 | dataset = EventDataset( 69 | h5_directory=h5_directory, 70 | signals=signals, 71 | events=events, 72 | records=sorted(records), 73 | window=window, 74 | fs=64, 75 | minimum_overlap=0.5, 76 | transformations=lambda x: x 77 | ) 78 | 79 | signal, events = dataset[0] 80 | 81 | assert tuple(signal.shape) == (2, int(window * dataset.fs)) 82 | 83 | assert len(dataset) == 360 84 | 85 | assert signal[0][6].tolist() == -0.11056432873010635 86 | 87 | 88 | def test_balanced_dataset_ratio_1(h5_directory, signals, events, records): 89 | 90 | dataset = BalancedEventDataset( 91 | h5_directory=h5_directory, 92 | signals=signals, 93 | events=events, 94 | window=1, 95 | fs=64, 96 | records=None, 97 | minimum_overlap=0.5, 98 | transformations=lambda x: x, 99 | ratio_positive=1, 100 | ) 101 | 102 | signal, events_data = dataset[0] 103 | 104 | assert tuple(signal.shape) == (2, 64) 105 | assert events_data.shape[1] == 3 106 | 107 | number_of_events = sum( 108 | [len(dataset.get_record_events(record)[0]) for record in records] 109 | ) 110 | assert number_of_events == len(dataset) == 103 111 | 112 | assert len(list(dataset.get_record_batch(records[0], 17))) == 22 113 | 114 | 115 | def test_balanced_dataset_ratio_0(h5_directory, signals, events, records): 116 | dataset = BalancedEventDataset( 117 | h5_directory=h5_directory, 118 | signals=signals, 119 | events=events, 120 | window=1, 121 | fs=64, 122 | records=None, 123 | minimum_overlap=0.5, 124 | transformations=lambda x: x, 125 | ratio_positive=0, 126 | ) 127 | 128 | signal, events_data = dataset[0] 129 | 130 | assert len(events_data) == 0 131 | assert len(dataset) == 103 132 | 133 | 134 | def mock_clip_and_normalize(min_value, max_value): 135 | def clipper(x, min_value=min_value, max_value=max_value): 136 | time.sleep(1) 137 | return x 138 | return clipper 139 | 140 | 141 | normalizer = { 142 | "clip_and_normalize": mock_clip_and_normalize 143 | } 144 | 145 | 146 | 
@patch("dosed.utils.data_from_h5.normalizers", normalizer) 147 | def test_parallel_is_faster(h5_directory, signals, events, records, cache_directory): 148 | 149 | dataset_parameters = { 150 | "h5_directory": h5_directory, 151 | "signals": signals, 152 | "events": events, 153 | "window": 1, 154 | "fs": 64, 155 | "records": None, 156 | "minimum_overlap": 0.5, 157 | "ratio_positive": 0.5, 158 | "cache_data": False, 159 | } 160 | 161 | shutil.rmtree(cache_directory, ignore_errors=True) 162 | t1 = time.time() 163 | BalancedEventDataset( 164 | n_jobs=-1, 165 | **dataset_parameters 166 | ) 167 | t1 = time.time() - t1 168 | 169 | shutil.rmtree(cache_directory, ignore_errors=True) 170 | t2 = time.time() 171 | BalancedEventDataset( 172 | n_jobs=1, 173 | **dataset_parameters, 174 | ) 175 | t2 = time.time() - t2 176 | 177 | assert t2 > t1 178 | 179 | 180 | def test_cache_is_faster(h5_directory, signals, events, records, cache_directory): 181 | dataset_parameters = { 182 | "h5_directory": h5_directory, 183 | "signals": signals, 184 | "events": events, 185 | "window": 1, 186 | "fs": 64, 187 | "records": None, 188 | "minimum_overlap": 0.5, 189 | "ratio_positive": 0.5, 190 | } 191 | 192 | shutil.rmtree(cache_directory, ignore_errors=True) 193 | t1 = time.time() 194 | BalancedEventDataset( 195 | cache_data=True, 196 | **dataset_parameters 197 | ) 198 | t1 = time.time() - t1 199 | 200 | t2 = time.time() 201 | BalancedEventDataset( 202 | cache_data=True, 203 | **dataset_parameters, 204 | ) 205 | t2 = time.time() - t2 206 | 207 | assert t2 < t1 208 | 209 | 210 | def test_cache_no_cache(h5_directory, signals, events, records, cache_directory): 211 | dataset_parameters = { 212 | "h5_directory": h5_directory, 213 | "signals": signals, 214 | "events": events, 215 | "window": 1, 216 | "fs": 64, 217 | "records": None, 218 | "minimum_overlap": 0.5, 219 | "ratio_positive": 0.5, 220 | "n_jobs": -1, 221 | } 222 | 223 | shutil.rmtree(cache_directory, ignore_errors=True) 224 | BalancedEventDataset( 225 | cache_data=False, 226 | **dataset_parameters 227 | ) 228 | assert not os.path.isdir(cache_directory) 229 | 230 | BalancedEventDataset( 231 | cache_data=True, 232 | **dataset_parameters, 233 | ) 234 | assert os.path.isdir(cache_directory) 235 | -------------------------------------------------------------------------------- /dosed/trainers/base.py: -------------------------------------------------------------------------------- 1 | """ Trainer class basic with SGD optimizer """ 2 | 3 | import copy 4 | import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | 11 | from ..datasets import collate 12 | from ..functions import (loss_functions, available_score_functions, compute_metrics_dataset) 13 | from ..utils import (match_events_localization_to_default_localizations, Logger) 14 | 15 | 16 | class TrainerBase: 17 | """Trainer class basic """ 18 | 19 | def __init__( 20 | self, 21 | net, 22 | optimizer_parameters={ 23 | "lr": 0.001, 24 | "weight_decay": 1e-8, 25 | }, 26 | loss_specs={ 27 | "type": "focal", 28 | "parameters": { 29 | "number_of_classes": 1, 30 | "alpha": 0.25, 31 | "gamma": 2, 32 | "device": torch.device("cuda"), 33 | } 34 | }, 35 | metrics=["precision", "recall", "f1"], 36 | epochs=100, 37 | metric_to_maximize="f1", 38 | patience=None, 39 | save_folder=None, 40 | logger_parameters={ 41 | "num_events": 1, 42 | "output_dir": None, 43 | "output_fname": 'train_history.json', 44 | "metrics": ["precision", "recall", "f1"], 45 | 
"name_events": ["event_type_1"] 46 | }, 47 | threshold_space={ 48 | "upper_bound": 0.85, 49 | "lower_bound": 0.55, 50 | "num_samples": 5, 51 | "zoom_in": False, 52 | }, 53 | matching_overlap=0.5, 54 | ): 55 | 56 | self.net = net 57 | print("Device: ", net.device) 58 | self.loss_function = loss_functions[loss_specs["type"]]( 59 | **loss_specs["parameters"]) 60 | self.optimizer = optim.SGD(net.parameters(), **optimizer_parameters) 61 | self.metrics = { 62 | score: score_function for score, score_function in 63 | available_score_functions.items() 64 | if score in metrics + [metric_to_maximize] 65 | } 66 | self.epochs = epochs 67 | self.threshold_space = threshold_space 68 | self.metric_to_maximize = metric_to_maximize 69 | self.patience = patience if patience else epochs 70 | self.save_folder = save_folder 71 | self.matching_overlap = matching_overlap 72 | self.matching = match_events_localization_to_default_localizations 73 | if logger_parameters is not None: 74 | self.train_logger = Logger(**logger_parameters) 75 | 76 | def on_batch_start(self): 77 | pass 78 | 79 | def on_epoch_end(self): 80 | pass 81 | 82 | def validate(self, validation_dataset, threshold_space): 83 | """ 84 | Compute metrics on validation_dataset net for test_dataset and 85 | select best classification threshold 86 | """ 87 | 88 | best_thresh = -1 89 | best_metrics_epoch = { 90 | metric: -1 91 | for metric in self.metrics.keys() 92 | } 93 | 94 | # Compute predicted_events 95 | thresholds = np.sort( 96 | np.random.uniform(threshold_space["upper_bound"], 97 | threshold_space["lower_bound"], 98 | threshold_space["num_samples"])) 99 | 100 | for threshold in thresholds: 101 | metrics_thresh = compute_metrics_dataset( 102 | self.net, 103 | validation_dataset, 104 | threshold, 105 | ) 106 | 107 | # If 0 events predicted, all superiors thresh's will also predict 0 108 | if metrics_thresh == -1: 109 | if best_thresh in (self.threshold_space["upper_bound"], 110 | self.threshold_space["lower_bound"]): 111 | print( 112 | "Best classification threshold is " + 113 | "in the boundary ({})! ".format(best_thresh) + 114 | "Consider extending threshold range") 115 | return best_metrics_epoch, best_thresh 116 | 117 | # Add to logger 118 | if "train_logger" in vars(self): 119 | self.train_logger.add_new_metrics((metrics_thresh, threshold)) 120 | 121 | # Compute mean metric to maximize across events 122 | mean_metric_to_maximize = np.nanmean( 123 | [m[self.metric_to_maximize] for m in metrics_thresh]) 124 | 125 | if mean_metric_to_maximize >= best_metrics_epoch[ 126 | self.metric_to_maximize]: 127 | best_metrics_epoch = { 128 | metric: np.nanmean( 129 | [m[metric] for m in metrics_thresh]) 130 | for metric in self.metrics.keys() 131 | } 132 | 133 | best_thresh = threshold 134 | 135 | if best_thresh in (threshold_space["upper_bound"], 136 | threshold_space["lower_bound"]): 137 | print("Best classification threshold is " + 138 | "in the boundary ({})! 
".format(best_thresh) + 139 | "Consider extending threshold range") 140 | 141 | return best_metrics_epoch, best_thresh 142 | 143 | def get_batch_loss(self, data): 144 | """ Single forward and backward pass """ 145 | 146 | # Get signals and labels 147 | signals, events = data 148 | x = signals.to(self.net.device) 149 | 150 | # Forward 151 | localizations, classifications, localizations_default = self.net.forward(x) 152 | 153 | # Matching 154 | localizations_target, classifications_target = self.matching( 155 | localizations_default=localizations_default, 156 | events=events, 157 | threshold_overlap=self.matching_overlap) 158 | localizations_target = localizations_target.to(self.net.device) 159 | classifications_target = classifications_target.to(self.net.device) 160 | 161 | # Loss 162 | (loss_classification_positive, 163 | loss_classification_negative, 164 | loss_localization) = ( 165 | self.loss_function(localizations, 166 | classifications, 167 | localizations_target, 168 | classifications_target)) 169 | 170 | return loss_classification_positive, \ 171 | loss_classification_negative, \ 172 | loss_localization 173 | 174 | def train(self, train_dataset, validation_dataset, batch_size=128): 175 | """ Metwork training with backprop """ 176 | 177 | dataloader_parameters = { 178 | "num_workers": 0, 179 | "shuffle": True, 180 | "collate_fn": collate, 181 | "pin_memory": True, 182 | "batch_size": batch_size, 183 | } 184 | dataloader_train = DataLoader(train_dataset, **dataloader_parameters) 185 | dataloader_val = DataLoader(validation_dataset, **dataloader_parameters) 186 | 187 | metrics_final = { 188 | metric: 0 189 | for metric in self.metrics.keys() 190 | } 191 | 192 | best_value = -np.inf 193 | best_threshold = None 194 | best_net = None 195 | counter_patience = 0 196 | last_update = None 197 | t = tqdm.tqdm(range(self.epochs,)) 198 | for epoch, _ in enumerate(t): 199 | if epoch != 0: 200 | t.set_postfix( 201 | best_metric_score=best_value, 202 | threshold=best_threshold, 203 | last_update=last_update, 204 | ) 205 | 206 | epoch_loss_classification_positive_train = 0.0 207 | epoch_loss_classification_negative_train = 0.0 208 | epoch_loss_localization_train = 0.0 209 | 210 | epoch_loss_classification_positive_val = 0.0 211 | epoch_loss_classification_negative_val = 0.0 212 | epoch_loss_localization_val = 0.0 213 | 214 | for i, data in enumerate(dataloader_train, 0): 215 | 216 | # On batch start 217 | self.on_batch_start() 218 | 219 | self.optimizer.zero_grad() 220 | 221 | # Set network to train mode 222 | self.net.train() 223 | 224 | (loss_classification_positive, 225 | loss_classification_negative, 226 | loss_localization) = self.get_batch_loss(data) 227 | 228 | epoch_loss_classification_positive_train += \ 229 | loss_classification_positive 230 | epoch_loss_classification_negative_train += \ 231 | loss_classification_negative 232 | epoch_loss_localization_train += loss_localization 233 | 234 | loss = loss_classification_positive \ 235 | + loss_classification_negative \ 236 | + loss_localization 237 | loss.backward() 238 | 239 | # gradient descent 240 | self.optimizer.step() 241 | 242 | epoch_loss_classification_positive_train /= (i + 1) 243 | epoch_loss_classification_negative_train /= (i + 1) 244 | epoch_loss_localization_train /= (i + 1) 245 | 246 | for i, data in enumerate(dataloader_val, 0): 247 | 248 | (loss_classification_positive, 249 | loss_classification_negative, 250 | loss_localization) = self.get_batch_loss(data) 251 | 252 | epoch_loss_classification_positive_val += \ 253 | 
loss_classification_positive 254 | epoch_loss_classification_negative_val += \ 255 | loss_classification_negative 256 | epoch_loss_localization_val += loss_localization 257 | 258 | epoch_loss_classification_positive_val /= (i + 1) 259 | epoch_loss_classification_negative_val /= (i + 1) 260 | epoch_loss_localization_val /= (i + 1) 261 | 262 | metrics_epoch, threshold = self.validate( 263 | validation_dataset=validation_dataset, 264 | threshold_space=self.threshold_space, 265 | ) 266 | 267 | if self.threshold_space["zoom_in"] and threshold != -1: 268 | threshold_space_size = self.threshold_space["upper_bound"] - \ 269 | self.threshold_space["lower_bound"] 270 | zoom_metrics_epoch, zoom_threshold = self.validate( 271 | validation_dataset=validation_dataset, 272 | threshold_space={ 273 | "upper_bound": threshold + 0.1 * threshold_space_size, 274 | "lower_bound": threshold - 0.1 * threshold_space_size, 275 | "num_samples": self.threshold_space["num_samples"], 276 | }) 277 | if zoom_metrics_epoch[self.metric_to_maximize] > metrics_epoch[ 278 | self.metric_to_maximize]: 279 | metrics_epoch = zoom_metrics_epoch 280 | threshold = zoom_threshold 281 | 282 | if self.save_folder: 283 | self.net.save(self.save_folder + str(epoch) + "_net") 284 | 285 | if metrics_epoch[self.metric_to_maximize] > best_value: 286 | best_value = metrics_epoch[self.metric_to_maximize] 287 | best_threshold = threshold 288 | last_update = epoch 289 | best_net = copy.deepcopy(self.net) 290 | metrics_final = { 291 | metric: metrics_epoch[metric] 292 | for metric in self.metrics.keys() 293 | } 294 | counter_patience = 0 295 | else: 296 | counter_patience += 1 297 | 298 | if counter_patience > self.patience: 299 | break 300 | 301 | self.on_epoch_end() 302 | if "train_logger" in vars(self): 303 | self.train_logger.add_new_loss( 304 | epoch_loss_localization_train.item(), 305 | epoch_loss_classification_positive_train.item(), 306 | epoch_loss_classification_negative_train.item(), 307 | mode="train" 308 | ) 309 | self.train_logger.add_new_loss( 310 | epoch_loss_localization_val.item(), 311 | epoch_loss_classification_positive_val.item(), 312 | epoch_loss_classification_negative_val.item(), 313 | mode="validation" 314 | ) 315 | self.train_logger.add_current_metrics_to_history() 316 | self.train_logger.dump_train_history() 317 | 318 | return best_net, metrics_final, best_threshold 319 | -------------------------------------------------------------------------------- /dosed/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset Class for DOSED training""" 2 | 3 | import os 4 | import h5py 5 | import numpy as np 6 | from numpy.lib.stride_tricks import as_strided 7 | from matplotlib import gridspec 8 | from joblib import Memory, Parallel, delayed 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | from ..utils import get_h5_data, get_h5_events 14 | 15 | 16 | class EventDataset(Dataset): 17 | 18 | """Extract data and events from h5 files and provide efficient way to retrieve windows with 19 | their corresponding events. 20 | 21 | args 22 | ==== 23 | 24 | h5_directory: 25 | Location of the generic h5 files. 
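(one file per record; see the README's "H5 data format" section for the expected layout)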
26 |     signals:
27 |         The signals from the h5 we want to include, together with their normalization
28 |     events:
29 |         The events from the h5 we want to train on
30 |     window:
31 |         Window size in seconds
32 |     fs:
33 |         Sampling frequency at which the signals are processed (signals are resampled to fs if needed)
34 |     records:
35 |         Used to select a subset of records from h5_directory; default is None, which uses all available recordings
36 |     n_jobs:
37 |         Number of processes used to extract and normalize signals from h5 files.
38 |     cache_data:
39 |         Cache results of extraction and normalization of signals from h5_file in h5_directory + "/.cache"
40 |         We strongly recommend keeping the default value True to avoid memory overhead (cached arrays are memory-mapped).
41 |     minimum_overlap:
42 |         Minimum fraction of an event that must fall inside a window for an event on the edge to be considered included in that window
43 |     ratio_positive:
44 |         A sample within a training batch will have probability "ratio_positive" of containing at least one event (used by BalancedEventDataset)
45 | 
46 |     """
47 | 
48 |     def __init__(self,
49 |                  h5_directory,
50 |                  signals,
51 |                  window,
52 |                  fs,
53 |                  events=None,
54 |                  records=None,
55 |                  n_jobs=1,
56 |                  cache_data=True,
57 |                  minimum_overlap=0.5,
58 |                  transformations=None
59 |                  ):
60 | 
61 |         if events:
62 |             self.number_of_classes = len(events)
63 |         self.transformations = transformations
64 | 
65 |         # window parameters
66 |         self.window = window
67 | 
68 |         # records (all of them by default)
69 |         if records is not None:
70 |             for record in records:
71 |                 assert record in os.listdir(h5_directory)
72 |             self.records = records
73 |         else:
74 |             self.records = [x for x in os.listdir(h5_directory) if x != ".cache"]
75 | 
76 |         ###########################
77 |         # Checks on H5
78 |         self.fs = fs
79 | 
80 |         # check event names
81 |         if events:
82 |             assert len(set([event["name"] for event in events])) == len(events)
83 | 
84 |         # ### joblib cache
85 |         get_data = get_h5_data
86 |         get_events = get_h5_events
87 |         if cache_data:
88 |             memory = Memory(h5_directory + "/.cache/", mmap_mode="r", verbose=0)
89 |             get_data = memory.cache(get_h5_data)
90 |             get_events = memory.cache(get_h5_events)
91 | 
92 |         self.window_size = int(self.window * self.fs)
93 |         self.number_of_channels = len(signals)
94 |         # used in network architecture
95 |         self.input_shape = (self.number_of_channels, self.window_size)
96 |         self.minimum_overlap = minimum_overlap  # for events on the edge of window_size
97 | 
98 |         # Open signals and events
99 |         self.signals = {}
100 |         self.events = {}
101 |         self.index_to_record = []
102 |         self.index_to_record_event = []  # link index to record
103 | 
104 |         # Preprocess signals from records
105 |         data = Parallel(n_jobs=n_jobs, prefer="threads")(delayed(get_data)(
106 |             filename="{}/{}".format(h5_directory, record),
107 |             signals=signals,
108 |             fs=fs
109 |         ) for record in self.records)
110 | 
111 |         for record, data in zip(self.records, data):
112 |             signal_size = data.shape[-1]
113 |             number_of_windows = signal_size // self.window_size
114 | 
115 |             self.signals[record] = {
116 |                 "data": data,
117 |                 "size": signal_size,
118 |             }
119 | 
120 |             self.index_to_record.extend([
121 |                 {
122 |                     "record": record,
123 |                     "index": x * self.window_size
124 |                 } for x in range(number_of_windows)
125 |             ])
126 | 
127 |             if events:
128 |                 self.events[record] = {}
129 |                 number_of_events = 0
130 |                 events_indexes = set()
131 |                 max_index = signal_size - self.window_size
132 | 
133 |                 for label, event in enumerate(events):
134 |                     data = get_events(
135 |                         filename="{}/{}".format(h5_directory, record),
136 |                         event=event,
137 |                         fs=self.fs,
138 |                     )
139 | 
140 |                     number_of_events += data.shape[-1]
141 | 
self.events[record][event["name"]] = { 142 | "data": data, 143 | "label": label, 144 | } 145 | 146 | for start, duration in zip(*data): 147 | stop = start + duration 148 | duration_overlap = duration * self.minimum_overlap 149 | 150 | start_valid_index = int(round( 151 | max(0, start + duration_overlap - self.window_size + 1))) 152 | end_valid_index = int(round( 153 | min(max_index + 1, stop - duration_overlap))) 154 | 155 | indexes = list(range(start_valid_index, end_valid_index)) 156 | events_indexes.update(indexes) 157 | 158 | no_events_indexes = set(range(max_index + 1)) 159 | no_events_indexes = list(no_events_indexes.difference(events_indexes)) 160 | events_indexes = list(events_indexes) 161 | 162 | self.index_to_record_event.extend([ 163 | { 164 | "record": record, 165 | "max_index": max_index, 166 | "events_indexes": events_indexes, 167 | "no_events_indexes": no_events_indexes, 168 | } for _ in range(number_of_events) 169 | ]) 170 | 171 | def __len__(self): 172 | return len(self.index_to_record) 173 | 174 | def __getitem__(self, idx): 175 | signal, events = self.get_sample( 176 | record=self.index_to_record[idx]["record"], 177 | index=self.index_to_record[idx]["index"]) 178 | 179 | if self.transformations is not None: 180 | signal = self.transformations(signal) 181 | return signal, events 182 | 183 | def get_valid_events_index(self, index, starts, durations): 184 | """Return the events' indexes that have enough overlap with the given time index 185 | ex: index = 155 186 | starts = [10 140 150 165 2000] 187 | duration = [4 20 10 10 40] 188 | minimum_overlap = 0.5 189 | window_size = 15 190 | return: [2 3] 191 | """ 192 | # Relative start stop 193 | 194 | starts_relative = (starts - index) / self.window_size 195 | durations_relative = durations / self.window_size 196 | stops_relative = starts_relative + durations_relative 197 | 198 | # Find valid start or stop 199 | valid_starts_index = np.where((starts_relative > 0) * 200 | (starts_relative < 1))[0] 201 | valid_stops_index = np.where((stops_relative > 0) * 202 | (stops_relative < 1))[0] 203 | 204 | valid_inside_index = np.where((starts_relative <= 0) * 205 | (stops_relative >= 1))[0] 206 | 207 | # merge them 208 | valid_indexes = set(list(valid_starts_index) + 209 | list(valid_stops_index) + 210 | list(valid_inside_index)) 211 | 212 | # Annotations contains valid index with minimum overlap requirement 213 | events_indexes = [] 214 | for valid_index in valid_indexes: 215 | if (valid_index in valid_starts_index) \ 216 | and (valid_index in valid_stops_index): 217 | events_indexes.append(valid_index) 218 | elif valid_index in valid_starts_index: 219 | if ((1 - starts_relative[valid_index]) / 220 | durations_relative[valid_index]) > self.minimum_overlap: 221 | events_indexes.append(valid_index) 222 | elif valid_index in valid_stops_index: 223 | if ((stops_relative[valid_index]) / durations_relative[valid_index]) \ 224 | > self.minimum_overlap: 225 | events_indexes.append(valid_index) 226 | elif valid_index in valid_inside_index: 227 | if self.window_size / durations[valid_index] > self.minimum_overlap: 228 | events_indexes.append(valid_index) 229 | return events_indexes 230 | 231 | def get_record_events(self, record): 232 | 233 | events = [[] for _ in range(self.number_of_classes)] 234 | 235 | for event_data in self.events[record].values(): 236 | events[event_data["label"]].extend([ 237 | [start, start + duration] 238 | for start, duration in event_data["data"].transpose().tolist() 239 | ]) 240 | 241 | return events 242 | 243 | def 
get_record_batch(self, record, batch_size, stride=None): 244 | """Return signal data from a specific record as a batch of continuous 245 | windows. Overlap in seconds allows overlapping among windows in the 246 | batch. The last data points will be ignored if their length is 247 | inferior to window_size. 248 | """ 249 | 250 | # stride = overlap_size 251 | # batch_size = batch 252 | 253 | stride = int((stride if stride is not None else self.window) * self.fs) 254 | batch_overlap_size = stride * batch_size # stride at a batch level 255 | read_size = (batch_size - 1) * stride + self.window_size 256 | signal_size = self.signals[record]["size"] 257 | t = np.arange(signal_size) 258 | number_of_batches_in_record = (signal_size - read_size) // batch_overlap_size + 1 259 | 260 | for batch in range(number_of_batches_in_record): 261 | start = batch_overlap_size * batch 262 | stop = batch_overlap_size * batch + read_size 263 | signal = self.signals[record]["data"][:, start:stop] 264 | 265 | signal_strided = torch.FloatTensor( 266 | as_strided( 267 | x=signal, 268 | shape=(batch_size, signal.shape[0], self.window_size), 269 | strides=(signal.strides[1] * stride, signal.strides[0], 270 | signal.strides[1]), 271 | ) 272 | ) 273 | time = t[start:stop] 274 | t_strided = as_strided( 275 | x=time, 276 | shape=(batch_size, self.window_size), 277 | strides=(time.strides[0] * stride, time.strides[0]), 278 | ) 279 | 280 | yield signal_strided, t_strided 281 | 282 | batch_end = ( 283 | signal_size - number_of_batches_in_record * batch_overlap_size - self.window_size 284 | ) // stride + 1 285 | if batch_end > 0: 286 | 287 | read_size_end = (batch_end - 1) * stride + self.window_size 288 | start = batch_overlap_size * number_of_batches_in_record 289 | end = batch_overlap_size * number_of_batches_in_record + read_size_end 290 | signal = self.signals[record]["data"][:, start:end] 291 | 292 | signal_strided = torch.FloatTensor( 293 | as_strided( 294 | x=signal, 295 | shape=(batch_end, signal.shape[0], self.window_size), 296 | strides=(signal.strides[1] * stride, signal.strides[0], 297 | signal.strides[1]), 298 | ) 299 | ) 300 | time = t[start:end] 301 | t_strided = as_strided( 302 | x=time, 303 | shape=(batch_end, self.window_size), 304 | strides=(time.strides[0] * stride, time.strides[0]), 305 | ) 306 | 307 | yield signal_strided, t_strided 308 | 309 | def plot(self, idx, channels): 310 | """Plot events and data from channels for record and index found at 311 | idx""" 312 | 313 | import matplotlib.pyplot as plt 314 | signal, events = self.extract_balanced_data( 315 | record=self.index_to_record_event[idx]["record"], 316 | max_index=self.index_to_record_event[idx]["max_index"]) 317 | 318 | non_valid_indexes = np.where(np.array(channels) is None)[0] 319 | signal = np.delete(signal, non_valid_indexes, axis=0) 320 | channels = [channel for channel in channels if channel is not None][::-1] 321 | 322 | num_signals = len(channels) 323 | signal_size = len(signal[0]) 324 | events_numpy = events.numpy() 325 | plt.figure(figsize=(10 * 4, 2 * num_signals)) 326 | gs = gridspec.GridSpec(num_signals, 1) 327 | gs.update(wspace=0., hspace=0.) 
328 |         for channel_num, channel in enumerate(channels):
329 |             assert signal_size == len(signal[channel_num])
330 |             signal_mean = signal.numpy()[channel_num].mean()
331 |             ax = plt.subplot(gs[channel_num, 0])
332 |             ax.set_ylim(-0.55, 0.55)
333 |             ax.plot(signal.numpy()[channel_num], alpha=0.3)
334 |             for event in events_numpy:
335 |                 ax.fill([event[0] * signal_size, event[1] * signal_size],
336 |                         [signal_mean, signal_mean],
337 |                         alpha=0.5,
338 |                         linewidth=30,
339 |                         color="C{}".format(int(event[-1])))
340 |             if channel_num == 0:
341 |                 # print(EVENT_DICT[event[2]])
342 |                 offset = (1. / num_signals) * 1.1
343 |                 step = (1. / num_signals) * 0.78
344 |             plt.gcf().text(0.915, offset + channel_num * step,
345 |                            channel, fontsize=14)
346 |         plt.show()
347 |         plt.close()
348 | 
349 |     def get_sample(self, record, index):
350 |         """Return a sample [data, events] from a record at a particular index"""
351 | 
352 |         signal_data = self.signals[record]["data"][:, index: index + self.window_size]
353 |         events_data = []
354 | 
355 |         for event_name, event in self.events[record].items():
356 |             starts, durations = event["data"][0, :], event["data"][1, :]
357 | 
358 |             # Relative start stop
359 |             starts_relative = (starts - index) / self.window_size
360 |             durations_relative = durations / self.window_size
361 |             stops_relative = starts_relative + durations_relative
362 | 
363 |             for valid_index in self.get_valid_events_index(index, starts, durations):
364 |                 events_data.append((max(0, float(starts_relative[valid_index])),
365 |                                     min(1, float(stops_relative[valid_index])),
366 |                                     event["label"]))
367 | 
368 |         return torch.FloatTensor(signal_data), torch.FloatTensor(events_data)
369 | 
370 | 
371 | class BalancedEventDataset(EventDataset):
372 |     """
373 |     Same as EventDataset, but with the possibility to choose the probability of getting at least
374 |     one event when retrieving a window.
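Sampling is driven by the extra argument ratio_positive: with that probability, the drawn window is taken from positions containing at least one event; otherwise it is drawn from event-free positions (see extract_balanced_data).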
375 | 376 | """ 377 | 378 | def __init__(self, 379 | h5_directory, 380 | signals, 381 | window, 382 | fs, 383 | events=None, 384 | records=None, 385 | minimum_overlap=0.5, 386 | transformations=None, 387 | ratio_positive=0.5, 388 | n_jobs=1, 389 | cache_data=True, 390 | ): 391 | super(BalancedEventDataset, self).__init__( 392 | h5_directory=h5_directory, 393 | signals=signals, 394 | events=events, 395 | window=window, 396 | fs=fs, 397 | records=records, 398 | minimum_overlap=minimum_overlap, 399 | transformations=transformations, 400 | n_jobs=n_jobs, 401 | cache_data=cache_data, 402 | ) 403 | self.ratio_positive = ratio_positive 404 | 405 | def __len__(self): 406 | return len(self.index_to_record_event) 407 | 408 | def __getitem__(self, idx): 409 | 410 | signal, events = self.extract_balanced_data( 411 | record=self.index_to_record_event[idx]["record"], 412 | max_index=self.index_to_record_event[idx]["max_index"], 413 | events_indexes=self.index_to_record_event[idx]["events_indexes"], 414 | no_events_indexes=self.index_to_record_event[idx]["no_events_indexes"] 415 | ) 416 | 417 | if self.transformations is not None: 418 | signal = self.transformations(signal) 419 | 420 | return signal, events 421 | 422 | def extract_balanced_data(self, record, max_index, events_indexes, no_events_indexes): 423 | """Extracts an index at random""" 424 | 425 | choice = np.random.choice([0, 1], p=[1 - self.ratio_positive, self.ratio_positive]) 426 | 427 | if choice == 0: 428 | index = no_events_indexes[np.random.randint(len(no_events_indexes))] 429 | else: 430 | index = events_indexes[np.random.randint(len(events_indexes))] 431 | 432 | signal_data, events_data = self.get_sample(record, index) 433 | 434 | return signal_data, events_data 435 | -------------------------------------------------------------------------------- /minimum_example/train_and_evaluate_dosed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dosed Training and Evaluation\n", 8 | "\n", 9 | "Requirements:\n", 10 | "\n", 11 | "1. You need the data for training as h5 files, to get them you either:\n", 12 | " - Go through `download_and_data_format_explanation.ipynb`\n", 13 | "\n", 14 | " or \n", 15 | "\n", 16 | " - Run `make download_example` you can use optionnal DOWNLOAD_PATH env variable to specify an alternative directory\n", 17 | " \n", 18 | "2. Have dosed installed:\n", 19 | " \n", 20 | " - Run `pip install -e .` from dosed root directory\n", 21 | " or \n", 22 | " - Run `python setup.py develop` from dosed root directory\n", 23 | " \n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%matplotlib inline" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import os\n", 42 | "import json\n", 43 | "\n", 44 | "h5_directory = '../data/h5' # adapt if you used a different DOWNLOAD_PATH when running `make download_example`" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# 1. 
Train, validation and test dataset creation\n", 52 | "\n", 53 | "## First we select which records we want to train, validate and test on" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "import torch\n", 63 | "import tempfile\n", 64 | "import json\n", 65 | "import random\n", 66 | "\n", 67 | "\n", 68 | "from dosed.utils import Compose\n", 69 | "from dosed.datasets import BalancedEventDataset as dataset\n", 70 | "from dosed.models import DOSED3 as model\n", 71 | "from dosed.datasets import get_train_validation_test\n", 72 | "from dosed.trainers import trainers\n", 73 | "from dosed.preprocessing import GaussianNoise, RescaleNormal, Invert\n", 74 | "\n", 75 | "seed = 2019" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "Number of records train: 11\n", 88 | "Number of records validation: 5\n", 89 | "Number of records test: 5\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "train, validation, test = get_train_validation_test(h5_directory,\n", 95 | " percent_test=25,\n", 96 | " percent_validation=33,\n", 97 | " seed=seed)\n", 98 | "\n", 99 | "print(\"Number of records train:\", len(train))\n", 100 | "print(\"Number of records validation:\", len(validation))\n", 101 | "print(\"Number of records test:\", len(test))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Then we use the dataset class that will be use to generate sample for training and evaluation\n", 109 | "\n", 110 | "- h5_directory: Location of the generic h5 files.\n", 111 | "- signals: the signals from the h5 we want to include together with their normalization\n", 112 | "- events: the evvents from the h5 we want to train on\n", 113 | "- window: Spindles have a duration of ~1 seconds, so we design the samples accordingly by choosing 10 seconds windows\n", 114 | "- ratio_positive: sample within a training batch will have a probability of \"ratio_positive\" to contain at least one spindle \n", 115 | "- n_jobs: number of process used to extract and normalize signals from h5 files.\n", 116 | "- cache_data: cache results of extraction and normalization of signals from h5_file in h5_directory + \"/.cache\" (we strongly recommand to set True)\n", 117 | "\n" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "window = 10 # window duration in seconds\n", 127 | "ratio_positive = 0.5 # When creating the batch, sample containing at least one spindle will be drawn with that probability\n", 128 | "\n", 129 | "fs = 32\n", 130 | "\n", 131 | "signals = [\n", 132 | " {\n", 133 | " 'h5_path': '/eeg_0',\n", 134 | " 'fs': 64,\n", 135 | " 'processing': {\n", 136 | " \"type\": \"clip_and_normalize\",\n", 137 | " \"args\": {\n", 138 | " \"min_value\": -150,\n", 139 | " \"max_value\": 150,\n", 140 | " }\n", 141 | " }\n", 142 | " },\n", 143 | " {\n", 144 | " 'h5_path': '/eeg_1',\n", 145 | " 'fs': 64,\n", 146 | " 'processing': {\n", 147 | " \"type\": \"clip_and_normalize\",\n", 148 | " \"args\": {\n", 149 | " \"min_value\": -150,\n", 150 | " \"max_value\": 150,\n", 151 | " }\n", 152 | " }\n", 153 | " }\n", 154 | "]\n", 155 | "\n", 156 | "events = [\n", 157 | " {\n", 158 | " \"name\": \"spindle\",\n", 159 | " \"h5_path\": \"spindle\",\n", 160 | " },\n", 161 | "]" 162 | ] 163 | }, 
164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "dataset_parameters = {\n", 171 | " \"h5_directory\": h5_directory,\n", 172 | " \"signals\": signals,\n", 173 | " \"events\": events,\n", 174 | " \"window\": window,\n", 175 | " \"fs\": fs,\n", 176 | " \"ratio_positive\": ratio_positive,\n", 177 | " \"n_jobs\": -1, # Make use of parallel computing to extract and normalize signals from h5\n", 178 | " \"cache_data\": True, # by default, stores normalized signals extracted from h5 in the h5_directory + \"/.cache\" directory\n", 179 | "}\n", 180 | "\n", 181 | "dataset_validation = dataset(records=validation, **dataset_parameters)\n", 182 | "dataset_test = dataset(records=test, **dataset_parameters)\n", 183 | "\n", 184 | "# for training, add data augmentation\n", 185 | "dataset_parameters_train = {\n", 186 | " \"transformations\": Compose([\n", 187 | " GaussianNoise(),\n", 188 | " RescaleNormal(),\n", 189 | " Invert(),\n", 190 | " ])\n", 191 | "}\n", 192 | "dataset_parameters_train.update(dataset_parameters)\n", 193 | "dataset_train = dataset(records=train, **dataset_parameters_train)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "# 2. Create a network\n", 201 | "\n", 202 | "The main parameters for the network are:\n", 203 | " - default_event_sizes: chosen according to the a priori size of the events to detect; here spindles last around 1 second\n", 204 | " - k_max: number of CNN layers\n", 205 | " " 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 7, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "default_event_sizes = [0.7, 1, 1.3]\n", 215 | "k_max = 5\n", 216 | "kernel_size = 5\n", 217 | "probability_dropout = 0.1\n", 218 | "device = torch.device(\"cuda\")" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 8, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "\n", 231 | "Input feature map size: 320\n", 232 | "Input receptive field: 0\n", 233 | "Input size in seconds: 10.0 s\n", 234 | "Input receptive field in seconds: 0.0 s \n", 235 | "\n", 236 | "After layer 1:\n", 237 | "\tFeature map size: 160\n", 238 | "\tReceptive field: 6\n", 239 | "\tReceptive field in seconds: 0.1875 s\n", 240 | "After layer 2:\n", 241 | "\tFeature map size: 80\n", 242 | "\tReceptive field: 16\n", 243 | "\tReceptive field in seconds: 0.5 s\n", 244 | "After layer 3:\n", 245 | "\tFeature map size: 40\n", 246 | "\tReceptive field: 36\n", 247 | "\tReceptive field in seconds: 1.125 s\n", 248 | "After layer 4:\n", 249 | "\tFeature map size: 20\n", 250 | "\tReceptive field: 76\n", 251 | "\tReceptive field in seconds: 2.375 s\n", 252 | "After layer 5:\n", 253 | "\tFeature map size: 10\n", 254 | "\tReceptive field: 156\n", 255 | "\tReceptive field in seconds: 4.875 s\n", 256 | "\n", 257 | "\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "sampling_frequency = dataset_train.fs\n", 263 | "\n", 264 | "net_parameters = {\n", 265 | " \"detection_parameters\": {\n", 266 | " \"overlap_non_maximum_suppression\": 0.5,\n", 267 | " \"classification_threshold\": 0.7\n", 268 | " },\n", 269 | " \"default_event_sizes\": [\n", 270 | " default_event_size * sampling_frequency\n", 271 | " for default_event_size in default_event_sizes\n", 272 | " ],\n", 273 | " \"k_max\": k_max,\n", 274 | " \"kernel_size\": kernel_size,\n", 275 | " \"pdrop\": probability_dropout,\n", 276 | " \"fs\": sampling_frequency, # only used to print architecture info with the right time scale\n", 277 | " \"input_shape\": dataset_train.input_shape,\n", 278 | " \"number_of_classes\": dataset_train.number_of_classes,\n", 279 | "}\n", 280 | "net = model(**net_parameters)\n", 281 | "net = net.to(device)" 282 | ] 283 | },
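{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The feature-map sizes and receptive fields printed above can be reproduced with a simple recurrence, assuming each of the `k_max` blocks is a kernel-5, stride-1 convolution followed by a kernel-2, stride-2 max-pooling — an assumption made for illustration; the authoritative architecture is in `dosed/models/dosed3.py`:\n",
  "\n",
  "```python\n",
  "fs = 32                  # sampling frequency in Hz\n",
  "kernel_size = 5\n",
  "feature_map_size = 320   # 10 s window * 32 Hz\n",
  "receptive_field = 1      # in input samples\n",
  "jump = 1                 # spacing of adjacent output positions, in input samples\n",
  "\n",
  "for layer in range(1, 6):\n",
  "    receptive_field += (kernel_size - 1) * jump  # convolution widens the field\n",
  "    receptive_field += 1 * jump                  # kernel-2 max-pool adds one step\n",
  "    jump *= 2                                    # stride-2 pooling doubles the jump\n",
  "    feature_map_size //= 2\n",
  "    print(layer, feature_map_size, receptive_field, receptive_field / fs)\n",
  "```\n",
  "\n",
  "This prints receptive fields of 6, 16, 36, 76 and 156 samples (0.1875 s up to 4.875 s), matching the architecture printout."
 ]
},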
284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# 3. Train the network\n", 289 | "\n", 290 | "The main training parameters are:\n", 291 | " - learning_rate\n", 292 | " - loss type" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 9, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "optimizer_parameters = {\n", 302 | " \"lr\": 5e-3,\n", 303 | " \"weight_decay\": 1e-8,\n", 304 | "}\n", 305 | "loss_specs = {\n", 306 | " \"type\": \"focal\",\n", 307 | " \"parameters\": {\n", 308 | " \"number_of_classes\": dataset_train.number_of_classes,\n", 309 | " \"device\": device,\n", 310 | " }\n", 311 | "}\n", 312 | "epochs = 20" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 10, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stderr", 322 | "output_type": "stream", 323 | "text": [ 324 | "\r", 325 | " 0%| | 0/20 [00:00<?, ?it/s]" 387 | ] 388 | }, 389 | "metadata": { 390 | "needs_background": "light" 391 | }, 392 | "output_type": "display_data" 393 | } 394 | ], 395 | "source": [ 396 | "import matplotlib.pyplot as plt\n", 397 | "import numpy as np\n", 398 | "\n", 399 | "record = dataset_test.records[1]\n", 400 | "\n", 401 | "index_spindle = 30\n", 402 | "window_duration = 5\n", 403 | "\n", 404 | "# retrieve the spindle at the right index\n", 405 | "spindle_start = float(predictions[record][0][index_spindle][0]) / sampling_frequency\n", 406 | "spindle_end = float(predictions[record][0][index_spindle][1]) / sampling_frequency\n", 407 | "\n", 408 | "# center the data window on the predicted spindle\n", 409 | "start_window = spindle_start + (spindle_end - spindle_start) / 2 - window_duration\n", 410 | "stop_window = spindle_start + (spindle_end - spindle_start) / 2 + window_duration\n", 411 | "\n", 412 | "# Retrieve EEG data at the right indices\n", 413 | "index_start = int(start_window * sampling_frequency)\n", 414 | "index_stop = int(stop_window * sampling_frequency)\n", 415 | "y = dataset_test.signals[record][\"data\"][0][index_start:index_stop]\n", 416 | "\n", 417 | "# Build the corresponding time support\n", 418 | "t = start_window + np.cumsum(np.ones(index_stop - index_start) * 1 / sampling_frequency)\n", 419 | "\n", 420 | "plt.figure(figsize=(16, 5))\n", 421 | "plt.plot(t, y)\n", 422 | "plt.axvline(spindle_end)\n", 423 | "plt.axvline(spindle_start)\n", 424 | "plt.ylim([-1, 1])\n", 425 | "plt.show()" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [] 434 | } 435 | ], 436 | "metadata": { 437 | "kernelspec": { 438 | "display_name": "pytorch", 439 | "language": "python", 440 | "name": "pytorch" 441 | }, 442 | "language_info": { 443 | "codemirror_mode": { 444 | "name": "ipython", 445 | "version": 3 446 | }, 447 | "file_extension": ".py", 448 | "mimetype": "text/x-python", 449 | "name": "python", 450 | "nbconvert_exporter": "python", 451 | "pygments_lexer": "ipython3", 452 | "version": "3.6.8" 453 | } 454 | }, 455 | "nbformat": 4, 456 | "nbformat_minor": 1 457 | } 458 | --------------------------------------------------------------------------------
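The `detection_parameters` passed to the network in the notebook above (`overlap_non_maximum_suppression`, `classification_threshold`) control how overlapping candidate events are pruned at prediction time; the repository's own implementation lives in `dosed/utils/non_maximum_suppression.py` and `dosed/utils/jaccard_overlap.py`. Below is a minimal, self-contained sketch of 1-D non-maximum suppression over `[start, end]` events — the function names and exact behavior are illustrative assumptions, not the library's API:

```python
def jaccard_overlap(a, b):
    """Intersection-over-union of two [start, end] intervals."""
    intersection = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - intersection
    return intersection / union if union > 0 else 0.0


def non_maximum_suppression(events, scores, overlap=0.5):
    """Greedily keep the highest-scoring events, dropping any candidate
    whose IoU with an already-kept event exceeds the overlap threshold."""
    order = sorted(range(len(events)), key=lambda i: scores[i], reverse=True)
    kept = []
    for i in order:
        if all(jaccard_overlap(events[i], events[j]) <= overlap for j in kept):
            kept.append(i)
    return [events[i] for i in kept]


# [2, 12] overlaps [0, 10] with IoU 8/12 ~ 0.67 > 0.5, so it is suppressed.
print(non_maximum_suppression([[0, 10], [2, 12], [20, 30]], [0.9, 0.8, 0.7]))
# -> [[0, 10], [20, 30]]
```

Greedy IoU-based suppression of this kind is the standard way to collapse the many overlapping default boxes a detector emits into a single event per detection.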