├── .gitignore ├── LICENSE ├── README.md ├── config.yaml ├── data └── .gitkeep ├── env.yaml ├── experiments └── .gitkeep ├── pickle_to_csv.py ├── predict.py ├── src ├── base │ ├── __init__.py │ ├── base_data_loader.py │ ├── base_model.py │ └── base_trainer.py ├── configs │ ├── abstract.yaml │ ├── data_base.json │ ├── data_cfs.json │ ├── data_isruc.json │ ├── data_mros.json │ ├── data_shhs.json │ ├── data_ssc.json │ ├── data_wsc.json │ ├── exp01-hu000.yaml │ ├── exp01-hu1024-sgd-clr-10min.yaml │ ├── exp01-hu1024-sgd-clr-2min.yaml │ ├── exp01-hu1024-sgd-clr-3min.yaml │ ├── exp01-hu1024-sgd-clr-4min.yaml │ ├── exp01-hu1024-sgd-clr-5min.yaml │ ├── exp01-hu1024-sgd-clr.yaml │ ├── exp01-hu1024.yaml │ ├── exp01-hu128.yaml │ ├── exp01-hu2048.yaml │ ├── exp01-hu256.yaml │ ├── exp01-hu4096.yaml │ ├── exp01-hu512.yaml │ ├── exp01-hu64.yaml │ ├── exp02-frac00025.yaml │ ├── exp02-frac0005.yaml │ ├── exp02-frac001.yaml │ ├── exp02-frac005.yaml │ ├── exp02-frac010.yaml │ ├── exp02-frac025.yaml │ ├── exp02-frac050.yaml │ ├── exp02-frac075.yaml │ ├── exp02-frac100.yaml │ ├── exp02-isruc.yaml │ ├── exp02-loci-isruc-sgd-clr.yaml │ ├── exp02-loci-isruc-wd.yaml │ ├── exp02-loci-isruc.yaml │ ├── exp02-loci-mros-wd.yaml │ ├── exp02-loci-mros.yaml │ ├── exp02-loci-shhs-wd.yaml │ ├── exp02-loci-shhs.yaml │ ├── exp02-loci-ssc-wd.yaml │ ├── exp02-loci-ssc.yaml │ ├── exp02-loci-wsc-sgd-clr.yaml │ ├── exp02-loci-wsc-wd.yaml │ ├── exp02-loci-wsc.yaml │ ├── exp02-loco-isruc.yaml │ ├── exp02-loco-mros.yaml │ ├── exp02-loco-shhs.yaml │ ├── exp02-loco-ssc.yaml │ ├── exp02-loco-wsc.yaml │ ├── exp02-mros.yaml │ ├── exp02-wsc.yaml │ ├── exp03-frac100-wd.yaml │ ├── exp03-frac100.yaml │ ├── exp03-isruc.yaml │ ├── exp03-mros.yaml │ ├── exp03-shhs.yaml │ ├── exp03-ssc.yaml │ ├── exp03-wsc.yaml │ ├── exp04-isruc-mros-shhs-ssc.yaml │ ├── exp04-isruc-mros-shhs-wsc.yaml │ ├── exp04-isruc-mros-shhs.yaml │ ├── exp04-isruc-mros-ssc-wsc.yaml │ ├── exp04-isruc-mros-ssc.yaml │ ├── exp04-isruc-mros-wsc.yaml │ ├── 
exp04-isruc-mros.yaml │ ├── exp04-isruc-shhs-ssc-wsc.yaml │ ├── exp04-isruc-shhs-ssc.yaml │ ├── exp04-isruc-shhs-wsc.yaml │ ├── exp04-isruc-shhs.yaml │ ├── exp04-isruc-ssc-wsc.yaml │ ├── exp04-isruc-ssc.yaml │ ├── exp04-isruc-wsc.yaml │ ├── exp04-mros-shhs-ssc-wsc.yaml │ ├── exp04-mros-shhs-ssc.yaml │ ├── exp04-mros-shhs-wsc.yaml │ ├── exp04-mros-shhs.yaml │ ├── exp04-mros-ssc-wsc.yaml │ ├── exp04-mros-ssc.yaml │ ├── exp04-mros-wsc.yaml │ ├── exp04-shhs-ssc-wsc.yaml │ ├── exp04-shhs-ssc.yaml │ ├── exp04-shhs-wsc.yaml │ ├── exp04-ssc-wsc.yaml │ ├── isruc.yaml │ ├── signal_labels │ │ ├── cfs.json │ │ ├── isruc.json │ │ ├── mros.json │ │ ├── shhs.json │ │ ├── ssc.json │ │ └── wsc.json │ ├── test-dataset.yaml │ ├── test-rnn_model.yaml │ └── test.yaml ├── data │ ├── .gitkeep │ └── generate_cohort_files.py ├── data_loader │ ├── __init__.py │ ├── balanced_dataset.py │ ├── data_loaders.py │ └── dataset.py ├── model │ ├── __init__.py │ ├── losses.py │ ├── metrics.py │ └── rnn_model.py ├── trainer │ ├── __init__.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── channel_label_identifier.py │ ├── config.py │ ├── ensure_dir.py │ ├── factory.py │ ├── logger.py │ ├── parallel_bar.py │ ├── parseXmlEdfp.py │ ├── pickle_reader.py │ ├── segmentation.py │ └── visualization.py ├── train.py └── trained_models └── best_weights.pth /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alexander Neergaard Olesen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and 
this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deep-sleep-pytorch 2 | 3 | 4 | 5 | ## Requirements 6 | Principal requirements are Python 3.7.3, PyTorch 1.1.0, CUDA 10 and cuDNN 7.5.1. 7 | An environment YAML has been provided with other packages required for this repository, which can be installed by running `conda env create -f env.yaml`. 8 | 9 | ## Preliminaries 10 | In the following, `COHORT_NAME` will designate the name of a custom cohort. 11 | 12 | ## Data preparation 13 | 14 | ### Set up channel label JSON file 15 | Use the `utils/channel_label_identifier.py` tool by running 16 | ``` 17 | python src/utils/channel_label_identifier.py src/configs/signal_labels/.json C3 C4 A1 A2 EOGL EOGR LChin RChin EMG 18 | ``` 19 | This will create a JSON file containing key-value pairs to map the desired electrode labels shown above with the electrode configurations available in the data. 20 | 21 | ### Setup data processing configuration file 22 | 1. Copy the contents of `src.configs.data_base.json` to another file `data_.json` in the same directory. 23 | 2. 
Insert the name of your test data `COHORT_NAME` in line 3 of the file, and change the `edf` and `stage` paths to point to the location of your EDFs and hypnograms (this can be the same directory). 24 | 3. (optional) Change the output directory to a custom location 25 | 26 | ### Run data pipeline to generate H5 files 27 | 1. Modify the code that returns a list of hypnograms (around line 320) for your specific use-case. 28 | 2. Add a routine to extract subject ID (`list_subjectID`) from filenames around line 363. 29 | 3. Add a routine to extract hypnogram in the `process_file()` function around line 118. The output shape of the `hypnogram` variable should be `(N,)` (a 1D array), where `N` is the number of 30 s epochs. 30 | 4. If you have lights-off/on information, you can include a routine in `process_file()` around line 266. 31 | 5. If you have non-AASM standard sleep scoring in your hypnograms (i.e. values outside of {W, N1, N2, N3, R} --> {0, 1, 2, 3, 4}), you can add a routine around line 282. 32 | 33 | Now run the data generation pipeline using 34 | ``` 35 | python -m src.data.generate_cohort_files -c data_.json 36 | ``` 37 | this will generate the H5 files containing the EDF/hypnogram data, and a CSV file containing an overview over the used files. 38 | 39 | ## Inference on new data 40 | ### Set up configuration file 41 | 1. Change the `data.test` parameter in `config.yaml` corresponding to your `COHORT_NAME` variable. 42 | 2. Change the `exp.name` parameter to your `EXPERIMENT_NAME`. 43 | 3. Change the `trainer.log_dir` parameter to `experiments/`. 44 | 4. (Optional) If using more than 1 GPU, change the `trainer.n_gpu` parameter to the number of GPUs. 45 | 5. Change data.data_dir if output directory was set to custom location 46 | 47 | ### Run script 48 | ``` 49 | python predict.py -c config.yaml -r trained_models/best_weights.pth 50 | ``` 51 | 52 | ## Citation 53 | A. N. Olesen, P. J. Jennum, E. Mignot, H. B. D. Sorensen.
Automatic sleep stage classification with deep residual networks in a mixed-cohort setting. *Sleep*, Volume 44, Issue 1, January 2021, zsaa161. [DOI:10.1093/sleep/zsaa161](https://doi.org/10.1093/sleep/zsaa161) 54 | ``` 55 | @article{Olesen2020, 56 | author = {Olesen, Alexander Neergaard and {J{\o}rgen Jennum}, Poul and Mignot, Emmanuel and Sorensen, Helge Bjarup Dissing}, 57 | doi = {10.1093/sleep/zsaa161}, 58 | journal = {Sleep}, 59 | number = {1}, 60 | pages = {zsaa161}, 61 | title = {{Automatic sleep stage classification with deep residual networks in a mixed-cohort setting}}, 62 | volume = {44}, 63 | year = {2021} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | batch_size: 3 | eval: 256 4 | test: 256 5 | train: 256 6 | data: 7 | eval: 8 | test: 9 | - - cfs 10 | - test 11 | train: 12 | data_dir: ./data/processed 13 | import: src.data_loader.dataset.MultiCohortDataset 14 | modalities: 15 | - eeg 16 | - eog 17 | - emg 18 | num_classes: 5 19 | segment_length: 300 20 | train_fraction: null 21 | exp: 22 | name: my_config 23 | loss: 24 | import: src.model.losses.temporal_crossentropy_loss 25 | lr_scheduler: 26 | args: 27 | base_lr: 0.1 28 | max_lr: 0.5 29 | mode: triangular 30 | step_size_up: 500 31 | import: torch.optim.lr_scheduler.CyclicLR 32 | metrics: 33 | - overall_accuracy 34 | - balanced_accuracy 35 | - kappa 36 | network: 37 | filter_base: 4 38 | import: src.model.rnn_model.RnnModel 39 | kernel_size: 3 40 | max_pooling: 2 41 | num_blocks: 7 42 | rnn_bidirectional: true 43 | rnn_num_layers: 1 44 | rnn_num_units: 1024 45 | optimizer: 46 | args: 47 | lr: 0.1 48 | momentum: 0.9 49 | nesterov: true 50 | import: torch.optim.SGD 51 | trainer: 52 | epochs: 20 53 | log_dir: experiments/my_experiment 54 | monitor: min val_loss 55 | n_gpu: 1 56 | num_workers: 128 57 | save_dir: experiments 58 | 
save_freq: 1 59 | tensorboardX: false 60 | verbosity: 2 61 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neergaard/deep-sleep-pytorch/db5d6f59369baacabeac956201c3c6de09961945/data/.gitkeep -------------------------------------------------------------------------------- /experiments/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neergaard/deep-sleep-pytorch/db5d6f59369baacabeac956201c3c6de09961945/experiments/.gitkeep -------------------------------------------------------------------------------- /pickle_to_csv.py: -------------------------------------------------------------------------------- 1 | """ Conversion of pickle files to comma-separated-values (csv) file. 2 | You can supply either a directory containing .pkl files via the '-d/--data_dir' flag, or 3 | a single .pkl file using the '--file' flag. 4 | 5 | Alexander Neergaard Zahid, 2021. 
6 | """ 7 | 8 | import argparse 9 | import time 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | 13 | import pandas as pd 14 | 15 | from src.utils.pickle_reader import read_pickle 16 | 17 | 18 | def convert_to_csv(filepath): 19 | predictions, targets = read_pickle(filepath) 20 | df = pd.concat( 21 | [pd.DataFrame(targets, columns=["Hypnogram"]), pd.DataFrame(predictions, columns=["W", "N1", "N2", "N3", "R"])], 22 | axis=1, 23 | ) 24 | df.to_csv(filepath.parent / (filepath.stem + ".csv"), index=False) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description=__doc__) 29 | parser.add_argument("-d", "--data_dir", type=str, help="Path to directory containing .pkl files to convert.") 30 | parser.add_argument("-f", "--file", type=str, help="Path to specific file to convert.") 31 | args = parser.parse_args() 32 | 33 | assert (args.data_dir is not None and args.file is None) or ( 34 | args.data_dir is None and args.file is not None 35 | ), f"Specify either a data directory or a file, received data_dir={args.data_dir} and file={args.file}" 36 | 37 | if args.data_dir is not None: 38 | data_dir = Path(args.data_dir) 39 | list_files = sorted(list(data_dir.glob("**/*.pkl"))) 40 | # N = len(list_files) 41 | N = 10 42 | list_files = list_files[:N] 43 | if N == 0: 44 | print(f"No .pkl files found in directory!") 45 | else: 46 | print(f"Starting conversion of {N} .pkl files...") 47 | bar = tqdm(list_files) 48 | start = time.time() 49 | for filepath in bar: 50 | bar.set_description(filepath.stem) 51 | convert_to_csv(filepath) 52 | end = time.time() 53 | print(f"Finished, {N} files converted in {end-start} seconds.") 54 | elif args.file is not None: 55 | N = 1 56 | filepath = Path(args.file) 57 | print(f"Converting file {filepath.stem}...") 58 | start = time.time() 59 | convert_to_csv(filepath) 60 | end = time.time() 61 | print(f"Converted {filepath.stem} in {end-start} seconds.") 62 | 
import logging

import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Optionally carves a validation subset out of ``dataset``: when
    ``validation_split`` > 0, indices are shuffled with a fixed seed and split
    into train/validation parts served by ``SubsetRandomSampler``s, so the
    split is reproducible across runs.
    """

    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        """
        :param dataset: torch Dataset (anything with __len__/__getitem__)
        :param batch_size: batch size forwarded to DataLoader
        :param shuffle: whether to shuffle (forced False if a split sampler is used)
        :param validation_split: fraction of samples reserved for validation (0 or None disables)
        :param num_workers: worker processes for loading
        :param collate_fn: batch collation function
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        # Kept for split_validation(), which builds a second DataLoader over
        # the same dataset with the validation sampler.
        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """Return (train_sampler, valid_sampler) for fraction ``split``, or (None, None)."""
        # BUG FIX: the original `split == 0.0` check crashed on split=None;
        # falsy check disables the split for both 0.0 and None.
        if not split:
            return None, None

        idx_full = np.arange(self.n_samples)

        # BUG FIX: use a local seeded RNG instead of np.random.seed(0), which
        # clobbered the *global* NumPy RNG state as a side effect. RandomState(0)
        # produces the identical permutation, so existing splits are unchanged.
        rng = np.random.RandomState(0)
        rng.shuffle(idx_full)

        len_valid = int(self.n_samples * split)

        valid_idx = idx_full[:len_valid]
        train_idx = idx_full[len_valid:]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """Return a DataLoader over the validation subset, or None if no split was made."""
        if self.valid_sampler is None:
            return None
        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)


class BaseModel(nn.Module):
    """
    Base class for all models.

    Provides a per-class logger and a __str__ that appends the number of
    trainable parameters to the standard nn.Module summary.
    """

    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *inputs):
        """
        Forward pass logic; must be implemented by subclasses.

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters.
        """
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        params = sum(np.prod(p.size()) for p in trainable)
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
[eeg, eog, emg] 30 | num_classes: 5 31 | segment_length: 300 # Length in seconds 32 | network: 33 | import: src.model.rnn_model.RnnModel 34 | filter_base: 4 35 | kernel_size: 3 36 | max_pooling: 2 37 | num_blocks: 7 38 | rnn_bidirectional: true 39 | rnn_num_layers: 1 40 | rnn_num_units: 41 | loss: 42 | import: src.model.losses.nll_loss 43 | metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 44 | optimizer: 45 | import: Adam 46 | args: 47 | lr: 0.001 48 | weight_decay: 0 49 | amsgrad: true 50 | lr_scheduler: 51 | import: ReduceLROnPlateau 52 | mode: 'min' 53 | factor: 0.5 54 | patience: 5 55 | verbose: true 56 | trainer: 57 | early_stop: 10 58 | epochs: 100 # Number of training epochs 59 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 60 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
61 | n_gpu: 4 62 | num_workers: 128 63 | save_dir: experiments 64 | save_freq: 1 # save checkpoints every save_freq epochs 65 | tensorboardX: false # Enable tensorboardX visualization support 66 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 67 | -------------------------------------------------------------------------------- /src/configs/data_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "": { 4 | "edf": "./data/raw/EDF_PATH", 5 | "stage": "./data/raw/HYPNOGRAM_PATH" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | "eeg": [ 17 | 0.5, 18 | 35 19 | ], 20 | "emg": [ 21 | 10 22 | ], 23 | "eog": [ 24 | 0.5, 25 | 35 26 | ] 27 | }, 28 | "fs_resampling": 128, 29 | "order": { 30 | "eeg": 4, 31 | "emg": 4, 32 | "eog": 4 33 | } 34 | }, 35 | "OUTPUT_DIRECTORY": "./data/processed", 36 | "PARTITIONS": { 37 | "EVAL": 0.0, 38 | "TRAIN": 0.0 39 | }, 40 | "SEGMENTATION": { 41 | "EPOCH_LENGTH_SEC": 30, 42 | "SEGMENT_LENGTH": { 43 | "eeg": 30, 44 | "emg": 30, 45 | "eog": 30 46 | }, 47 | "SEGMENT_OVERLAP": { 48 | "eeg": 0, 49 | "emg": 0, 50 | "eog": 0 51 | }, 52 | "SEQUENCE_DURATION_MIN": 5 53 | }, 54 | "SUBSETS": [ 55 | "Test" 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /src/configs/data_cfs.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "cfs": { 4 | "edf": "./data/raw/cfs", 5 | "stage": "./data/raw/cfs" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | "eeg": [ 17 | 0.5, 18 | 35 19 | ], 20 | "emg": [ 21 | 10 22 | ], 23 | "eog": [ 24 | 0.5, 25 | 35 26 | ] 27 | }, 28 | "fs_resampling": 128, 29 | "order": { 30 | "eeg": 4, 31 | "emg": 4, 32 | "eog": 4 33 
| } 34 | }, 35 | "OUTPUT_DIRECTORY": "./data/processed", 36 | "PARTITIONS": { 37 | "EVAL": 0.0, 38 | "TRAIN": 0.0 39 | }, 40 | "SEGMENTATION": { 41 | "EPOCH_LENGTH_SEC": 30, 42 | "SEGMENT_LENGTH": { 43 | "eeg": 30, 44 | "emg": 30, 45 | "eog": 30 46 | }, 47 | "SEGMENT_OVERLAP": { 48 | "eeg": 0, 49 | "emg": 0, 50 | "eog": 0 51 | }, 52 | "SEQUENCE_DURATION_MIN": 5 53 | }, 54 | "SUBSETS": [ 55 | "Test" 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /src/configs/data_isruc.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "isruc": { 4 | "edf": "./data/raw/isruc", 5 | "stage": "./data/raw/isruc" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | "eeg": [ 17 | 0.5, 18 | 35 19 | ], 20 | "emg": [ 21 | 10 22 | ], 23 | "eog": [ 24 | 0.5, 25 | 35 26 | ] 27 | }, 28 | "fs_resampling": 128, 29 | "order": { 30 | "eeg": 4, 31 | "emg": 4, 32 | "eog": 4 33 | } 34 | }, 35 | "OUTPUT_DIRECTORY": "./data/processed_oak", 36 | "PARTITIONS": { 37 | "EVAL": 0.3, 38 | "TRAIN": 0.3 39 | }, 40 | "SEGMENTATION": { 41 | "EPOCH_LENGTH_SEC": 30, 42 | "SEGMENT_LENGTH": { 43 | "eeg": 30, 44 | "emg": 30, 45 | "eog": 30 46 | }, 47 | "SEGMENT_OVERLAP": { 48 | "eeg": 0, 49 | "emg": 0, 50 | "eog": 0 51 | }, 52 | "SEQUENCE_DURATION_MIN": 5 53 | }, 54 | "SUBSETS": [ 55 | "Train", 56 | "Eval", 57 | "Test" 58 | ] 59 | } -------------------------------------------------------------------------------- /src/configs/data_mros.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "mros": { 4 | "edf": "./data/raw/mros", 5 | "stage": "./data/raw/mros" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | 
"eeg": [ 17 | 0.5, 18 | 35 19 | ], 20 | "emg": [ 21 | 10 22 | ], 23 | "eog": [ 24 | 0.5, 25 | 35 26 | ] 27 | }, 28 | "fs_resampling": 128, 29 | "order": { 30 | "eeg": 4, 31 | "emg": 4, 32 | "eog": 4 33 | } 34 | }, 35 | "OUTPUT_DIRECTORY": "./data/processed_oak", 36 | "PARTITIONS": { 37 | "EVAL": 0.025, 38 | "TRAIN": 0.875 39 | }, 40 | "SEGMENTATION": { 41 | "EPOCH_LENGTH_SEC": 30, 42 | "SEGMENT_LENGTH": { 43 | "eeg": 30, 44 | "emg": 30, 45 | "eog": 30 46 | }, 47 | "SEGMENT_OVERLAP": { 48 | "eeg": 0, 49 | "emg": 0, 50 | "eog": 0 51 | }, 52 | "SEQUENCE_DURATION_MIN": 5 53 | }, 54 | "SUBSETS": [ 55 | "Train", 56 | "Eval", 57 | "Test" 58 | ] 59 | } -------------------------------------------------------------------------------- /src/configs/data_shhs.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "shhs": { 4 | "edf": "./data/raw/shhs", 5 | "stage": "./data/raw/shhs" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | "eeg": [ 17 | 0.5, 18 | 35 19 | ], 20 | "emg": [ 21 | 10 22 | ], 23 | "eog": [ 24 | 0.5, 25 | 35 26 | ] 27 | }, 28 | "fs_resampling": 128, 29 | "order": { 30 | "eeg": 4, 31 | "emg": 4, 32 | "eog": 4 33 | } 34 | }, 35 | "OUTPUT_DIRECTORY": "./data/processed_oak", 36 | "PARTITIONS": { 37 | "EVAL": 0.025, 38 | "TRAIN": 0.875 39 | }, 40 | "SEGMENTATION": { 41 | "EPOCH_LENGTH_SEC": 30, 42 | "SEGMENT_LENGTH": { 43 | "eeg": 30, 44 | "emg": 30, 45 | "eog": 30 46 | }, 47 | "SEGMENT_OVERLAP": { 48 | "eeg": 0, 49 | "emg": 0, 50 | "eog": 0 51 | }, 52 | "SEQUENCE_DURATION_MIN": 5 53 | }, 54 | "SUBSETS": [ 55 | "Train", 56 | "Eval", 57 | "Test" 58 | ] 59 | } -------------------------------------------------------------------------------- /src/configs/data_ssc.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "ssc": { 4 | 
"edf": "./data/raw/ssc", 5 | "stage": "./data/raw/ssc" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "SUBSETS": ["Train", "Eval", "Test"], 10 | "FILTERS": { 11 | "btype": { 12 | "eeg": "bandpass", 13 | "emg": "high", 14 | "eog": "bandpass" 15 | }, 16 | "fc": { 17 | "eeg": [0.5, 35], 18 | "emg": [10], 19 | "eog": [0.5, 35] 20 | }, 21 | "fs_resampling": 128, 22 | "order": { 23 | "eeg": 4, 24 | "emg": 4, 25 | "eog": 4 26 | } 27 | }, 28 | "OUTPUT_DIRECTORY": "./data/processed_oak", 29 | "PARTITIONS": { 30 | "EVAL": 0.025, 31 | "TRAIN": 0.875 32 | }, 33 | "SEGMENTATION": { 34 | "EPOCH_LENGTH_SEC": 30, 35 | "SEGMENT_LENGTH": { 36 | "eeg": 30, 37 | "emg": 30, 38 | "eog": 30 39 | }, 40 | "SEGMENT_OVERLAP": { 41 | "eeg": 0, 42 | "emg": 0, 43 | "eog": 0 44 | }, 45 | "SEQUENCE_DURATION_MIN": 5 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/configs/data_wsc.json: -------------------------------------------------------------------------------- 1 | { 2 | "COHORTS": { 3 | "wsc": { 4 | "edf": "./data/raw/wsc", 5 | "stage": "./data/raw/wsc" 6 | } 7 | }, 8 | "COHORT_OVERVIEW_FILE": [], 9 | "FILTERS": { 10 | "btype": { 11 | "eeg": "bandpass", 12 | "emg": "high", 13 | "eog": "bandpass" 14 | }, 15 | "fc": { 16 | "eeg": [0.5, 35], 17 | "emg": [10], 18 | "eog": [0.5, 35] 19 | }, 20 | "fs_resampling": 128, 21 | "order": { 22 | "eeg": 4, 23 | "emg": 4, 24 | "eog": 4 25 | } 26 | }, 27 | "OUTPUT_DIRECTORY": "./data/processed_oak", 28 | "PARTITIONS": { 29 | "EVAL": 0.025, 30 | "TRAIN": 0.875 31 | }, 32 | "SEGMENTATION": { 33 | "EPOCH_LENGTH_SEC": 30, 34 | "SEGMENT_LENGTH": { 35 | "eeg": 30, 36 | "emg": 30, 37 | "eog": 30 38 | }, 39 | "SEGMENT_OVERLAP": { 40 | "eeg": 0, 41 | "emg": 0, 42 | "eog": 0 43 | }, 44 | "SEQUENCE_DURATION_MIN": 5 45 | }, 46 | "SUBSETS": ["Train", "Eval", "Test"] 47 | } 48 | -------------------------------------------------------------------------------- /src/configs/exp01-hu000.yaml: 
-------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu000 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 0 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 
64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr-10min.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-clr-10min 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 128 7 | eval: 128 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 600 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 
62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr-2min.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-clr-2min 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 120 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: 
torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr-3min.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-clr-3min 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 180 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 
'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr-4min.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-clr-4min 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 240 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: 
src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr-5min.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-clr-5min 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: 
src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 
64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024-sgd-clr.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024-sgd-cycliclr 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | 
log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # Mode and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu1024.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu1024 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import:
torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu128.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu128 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 128 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 
'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 10 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu2048.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu2048 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 2048 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: 
[overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu256.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu256 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | 
num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 256 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu4096.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu4096 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | 
num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 4096 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | early_stop: 11 61 | epochs: 1000 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 
64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu512.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu512 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 512 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: 
experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # Mode and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp01-hu64.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp01-hu64 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 580 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 64 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 |
args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | # early_stop: 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 1 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac00025.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac00025 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.0025 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | 
optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 100 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac0005.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac0005 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.005 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 
45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 100 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac001.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac001 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.01 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: 
true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 100 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac005.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac005 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.05 31 | num_classes: 5 32 | segment_length: 300 # 
Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.25 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 
64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac010.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac010 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.1 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: 
experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac025.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac025 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.25 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 
55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac050.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac050 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.50 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | 
optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.5 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac075.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac075 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.75 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | 
# metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.1 56 | max_lr: 0.25 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 50 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-frac100.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-frac100 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | 
rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.SGD 48 | args: 49 | lr: 0.1 50 | momentum: 0.9 51 | nesterov: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.CyclicLR 54 | args: 55 | base_lr: 0.05 56 | max_lr: 0.15 57 | step_size_up: 500 58 | mode: 'triangular' 59 | trainer: 60 | # early_stop: 11 61 | epochs: 40 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-isruc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | # - [mros, test] 25 | # - [shhs, test] 26 | # - [ssc, test] 27 | # - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 0.15 31 | num_classes: 5 32 | 
segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 256 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: torch.optim.Adam 48 | args: 49 | lr: 0.0001 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | early_stop: 10 61 | epochs: 1000 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # Mode and metric for model performance monitoring. Set 'off' to disable.
64 | n_gpu: 4 65 | num_workers: 128 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-isruc-sgd-clr.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: 
torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.1 64 | max_lr: 0.5 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 10 69 | epochs: 100 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-isruc-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-isruc-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [mros, train] 29 | - [shhs, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [mros, eval] 33 | - [shhs, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: 
[overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0.0002 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 100 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-isruc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [mros, train] 29 | - [shhs, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [mros, eval] 33 | - [shhs, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | 
train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 100 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # Mode and metric for model performance monitoring. Set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-mros-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-mros-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [mros, test] 24 | - [isruc, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [shhs, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [shhs, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0.0002 59 | amsgrad: true 60 | lr_scheduler: 61 | import: 
torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-mros.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-mros 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [mros, test] 24 | - [isruc, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [shhs, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [shhs, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: 
[overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-shhs-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-shhs-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [shhs, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | 
train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0.0002 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # Mode and metric for model performance monitoring. Set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [shhs, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [ssc, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [ssc, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | 
args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-ssc-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-ssc-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [ssc, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [shhs, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [shhs, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [shhs, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # 
metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0.0002 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [ssc, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [shhs, test] 27 | - [wsc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [shhs, train] 31 | - [wsc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [shhs, eval] 35 | - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # 
Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.0001 58 | weight_decay: 0 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # Mode and metric for model performance monitoring. Set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-wsc-sgd-clr.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-wsc-sgd-clr 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [wsc, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [shhs, test] 27 | - [ssc, test] 28 | # - [isruc, train] 29 | # - [mros, train] 30 | # - [shhs, train] 31 | # - [ssc, train] 32 | # - [isruc, eval] 33 | # - [mros, eval] 34 | # - [shhs, eval] 35 | # - [ssc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: 
torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.1 64 | max_lr: 0.5 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-wsc-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-wsc-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [wsc, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [shhs, test] 27 | - [ssc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [shhs, train] 31 | - [ssc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [shhs, eval] 35 | - [ssc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: 
[overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.00001 58 | weight_decay: 0.0002 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loci-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loci-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [wsc, test] 24 | - [isruc, test] 25 | - [mros, test] 26 | - [shhs, test] 27 | - [ssc, test] 28 | - [isruc, train] 29 | - [mros, train] 30 | - [shhs, train] 31 | - [ssc, train] 32 | - [isruc, eval] 33 | - [mros, eval] 34 | - [shhs, eval] 35 | - [ssc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | 
train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.Adam 56 | args: 57 | lr: 0.00001 58 | weight_decay: 0 59 | amsgrad: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 62 | args: 63 | mode: 'min' 64 | factor: 0.5 65 | patience: 5 66 | verbose: true 67 | trainer: 68 | # early_stop: 10 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp02-loco-isruc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loco-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, eval] 24 | - [isruc, train] 25 | - [isruc, test] 26 | # - [mros, test] 27 | # - [shhs, test] 28 | # - [ssc, test] 29 | # - [wsc, test] 30 | # - [mros, train] 31 | # - [shhs, train] 32 | # - [ssc, train] 33 | # - [wsc, train] 34 | # - [mros, eval] 35 | # - [shhs, eval] 36 | # - [ssc, eval] 37 | # - [wsc, eval] 38 | data_dir: ./data/processed_oak 39 | modalities: [eeg, eog, emg] 40 | train_fraction: 41 | num_classes: 5 42 | segment_length: 300 # Length in seconds 43 | network: 44 | import: src.model.rnn_model.RnnModel 45 | filter_base: 4 46 | kernel_size: 3 47 | max_pooling: 2 48 | num_blocks: 7 49 | rnn_bidirectional: true 50 | rnn_num_layers: 1 51 | rnn_num_units: 1024 52 | loss: 53 | import: src.model.losses.temporal_crossentropy_loss 54 | metrics: [overall_accuracy, balanced_accuracy, kappa] 55 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 56 | optimizer: 57 | import: torch.optim.Adam 58 | args: 59 | lr: 0.0001 60 | weight_decay: 0 61 | amsgrad: true 62 | lr_scheduler: 63 | 
import: torch.optim.lr_scheduler.ReduceLROnPlateau 64 | args: 65 | mode: 'min' 66 | factor: 0.5 67 | patience: 5 68 | verbose: true 69 | trainer: 70 | # early_stop: 10 71 | epochs: 100 # Number of training epochs 72 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 73 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 74 | n_gpu: 4 75 | num_workers: 128 76 | save_dir: experiments 77 | save_freq: 1 # save checkpoints every save_freq epochs 78 | tensorboardX: false # Enable tensorboardX visualization support 79 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 80 | -------------------------------------------------------------------------------- /src/configs/exp02-loco-mros.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loco-mros 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [shhs, train] 13 | - [ssc, train] 14 | - [wsc, train] 15 | eval: 16 | - [isruc, eval] 17 | - [shhs, eval] 18 | - [ssc, eval] 19 | - [wsc, eval] 20 | test: 21 | - [mros, train] 22 | - [mros, eval] 23 | - [mros, test] 24 | data_dir: ./data/processed_oak 25 | modalities: [eeg, eog, emg] 26 | train_fraction: 27 | num_classes: 5 28 | segment_length: 300 # Length in seconds 29 | network: 30 | import: src.model.rnn_model.RnnModel 31 | filter_base: 4 32 | kernel_size: 3 33 | max_pooling: 2 34 | num_blocks: 7 35 | rnn_bidirectional: true 36 | rnn_num_layers: 1 37 | rnn_num_units: 1024 38 | loss: 39 | import: src.model.losses.temporal_crossentropy_loss 40 | metrics: [overall_accuracy, balanced_accuracy, kappa] 41 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 42 | optimizer: 43 | import: 
torch.optim.Adam 44 | args: 45 | lr: 0.0001 46 | weight_decay: 0 47 | amsgrad: true 48 | lr_scheduler: 49 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 50 | args: 51 | mode: 'min' 52 | factor: 0.5 53 | patience: 5 54 | verbose: true 55 | trainer: 56 | # early_stop: 10 57 | epochs: 50 # Number of training epochs 58 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 59 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 60 | n_gpu: 4 61 | num_workers: 128 62 | save_dir: experiments 63 | save_freq: 1 # save checkpoints every save_freq epochs 64 | tensorboardX: false # Enable tensorboardX visualization support 65 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 66 | -------------------------------------------------------------------------------- /src/configs/exp02-loco-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loco-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [ssc, train] 14 | - [wsc, train] 15 | eval: 16 | - [isruc, eval] 17 | - [mros, eval] 18 | - [ssc, eval] 19 | - [wsc, eval] 20 | test: 21 | - [shhs, train] 22 | - [shhs, eval] 23 | - [shhs, test] 24 | data_dir: ./data/processed_oak 25 | modalities: [eeg, eog, emg] 26 | train_fraction: 27 | num_classes: 5 28 | segment_length: 300 # Length in seconds 29 | network: 30 | import: src.model.rnn_model.RnnModel 31 | filter_base: 4 32 | kernel_size: 3 33 | max_pooling: 2 34 | num_blocks: 7 35 | rnn_bidirectional: true 36 | rnn_num_layers: 1 37 | rnn_num_units: 1024 38 | loss: 39 | import: src.model.losses.temporal_crossentropy_loss 40 | metrics: [overall_accuracy, balanced_accuracy, kappa] 41 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 
'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 42 | optimizer: 43 | import: torch.optim.Adam 44 | args: 45 | lr: 0.0001 46 | weight_decay: 0 47 | amsgrad: true 48 | lr_scheduler: 49 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 50 | args: 51 | mode: 'min' 52 | factor: 0.5 53 | patience: 5 54 | verbose: true 55 | trainer: 56 | # early_stop: 10 57 | epochs: 50 # Number of training epochs 58 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 59 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 60 | n_gpu: 4 61 | num_workers: 128 62 | save_dir: experiments 63 | save_freq: 1 # save checkpoints every save_freq epochs 64 | tensorboardX: false # Enable tensorboardX visualization support 65 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 66 | -------------------------------------------------------------------------------- /src/configs/exp02-loco-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loco-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [wsc, train] 15 | eval: 16 | - [isruc, eval] 17 | - [mros, eval] 18 | - [shhs, eval] 19 | - [wsc, eval] 20 | test: 21 | - [ssc, train] 22 | - [ssc, eval] 23 | - [ssc, test] 24 | data_dir: ./data/processed_oak 25 | modalities: [eeg, eog, emg] 26 | train_fraction: 27 | num_classes: 5 28 | segment_length: 300 # Length in seconds 29 | network: 30 | import: src.model.rnn_model.RnnModel 31 | filter_base: 4 32 | kernel_size: 3 33 | max_pooling: 2 34 | num_blocks: 7 35 | rnn_bidirectional: true 36 | rnn_num_layers: 1 37 | rnn_num_units: 1024 38 | loss: 39 | import: src.model.losses.temporal_crossentropy_loss 40 | metrics: [overall_accuracy, balanced_accuracy, 
kappa] 41 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 42 | optimizer: 43 | import: torch.optim.Adam 44 | args: 45 | lr: 0.0001 46 | weight_decay: 0 47 | amsgrad: true 48 | lr_scheduler: 49 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 50 | args: 51 | mode: 'min' 52 | factor: 0.5 53 | patience: 5 54 | verbose: true 55 | trainer: 56 | # early_stop: 10 57 | epochs: 50 # Number of training epochs 58 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 59 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 60 | n_gpu: 4 61 | num_workers: 128 62 | save_dir: experiments 63 | save_freq: 1 # save checkpoints every save_freq epochs 64 | tensorboardX: false # Enable tensorboardX visualization support 65 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 66 | -------------------------------------------------------------------------------- /src/configs/exp02-loco-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-loco-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 64 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | eval: 16 | - [isruc, eval] 17 | - [mros, eval] 18 | - [shhs, eval] 19 | - [ssc, eval] 20 | test: 21 | - [wsc, train] 22 | - [wsc, eval] 23 | - [wsc, test] 24 | data_dir: ./data/processed_oak 25 | modalities: [eeg, eog, emg] 26 | train_fraction: 27 | num_classes: 5 28 | segment_length: 300 # Length in seconds 29 | network: 30 | import: src.model.rnn_model.RnnModel 31 | filter_base: 4 32 | kernel_size: 3 33 | max_pooling: 2 34 | num_blocks: 7 35 | rnn_bidirectional: true 36 | rnn_num_layers: 1 37 | rnn_num_units: 1024 38 | loss: 39 | import: 
src.model.losses.temporal_crossentropy_loss 40 | metrics: [overall_accuracy, balanced_accuracy, kappa] 41 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 42 | optimizer: 43 | import: torch.optim.Adam 44 | args: 45 | lr: 0.00001 46 | weight_decay: 0 47 | amsgrad: true 48 | lr_scheduler: 49 | import: torch.optim.lr_scheduler.ReduceLROnPlateau 50 | args: 51 | mode: 'min' 52 | factor: 0.5 53 | patience: 5 54 | verbose: true 55 | trainer: 56 | # early_stop: 10 57 | epochs: 50 # Number of training epochs 58 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 59 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 60 | n_gpu: 4 61 | num_workers: 128 62 | save_dir: experiments 63 | save_freq: 1 # save checkpoints every save_freq epochs 64 | tensorboardX: false # Enable tensorboardX visualization support 65 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 66 | -------------------------------------------------------------------------------- /src/configs/exp02-mros.yaml: -------------------------------------------------------------------------------- 1 | __include__: abstract.yaml 2 | exp: 3 | name: exp02-mros 4 | data_loader: 5 | import: src.data_loader.dataset.MultiCohortDataset 6 | batch_size: 7 | train: 128 8 | eval: 64 9 | test: 64 10 | data: 11 | train: 12 | # - [isruc, train] 13 | - [mros, train] 14 | # - [shhs, train] 15 | # - [ssc, train] 16 | # - [wsc, train] 17 | eval: 18 | # - [isruc, eval] 19 | - [mros, eval] 20 | # - [shhs, eval] 21 | # - [ssc, eval] 22 | # - [wsc, eval] 23 | test: 24 | # - [isruc, test] 25 | - [mros, test] 26 | # - [shhs, test] 27 | # - [ssc, test] 28 | # - [wsc, test] 29 | data_dir: ./data/processed_oak 30 | modalities: [eeg, eog, emg] 31 | data_fraction: 0.05 32 | num_classes: 5 33 | segment_length: 300 # Length in seconds 34 | 
network: 35 | import: src.model.rnn_model.RnnModel 36 | filter_base: 4 37 | kernel_size: 3 38 | max_pooling: 2 39 | num_blocks: 7 40 | rnn_bidirectional: true 41 | rnn_num_layers: 1 42 | rnn_num_units: 256 43 | loss: 44 | import: src.model.losses.nll_loss 45 | metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | optimizer: 47 | import: Adam 48 | args: 49 | lr: 0.01 50 | weight_decay: 0 51 | amsgrad: true 52 | lr_scheduler: 53 | import: ReduceLROnPlateau 54 | args: 55 | mode: 'min' 56 | factor: 0.5 57 | patience: 5 58 | verbose: true 59 | trainer: 60 | early_stop: 10 61 | epochs: 100 # Number of training epochs 62 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 63 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 64 | n_gpu: 4 65 | num_workers: 0 66 | save_dir: experiments 67 | save_freq: 1 # save checkpoints every save_freq epochs 68 | tensorboardX: false # Enable tensorboardX visualization support 69 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 70 | -------------------------------------------------------------------------------- /src/configs/exp02-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 32 8 | test: 32 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | num_classes: 5 31 | 
segment_length: 300 # Length in seconds 32 | network: 33 | import: src.model.rnn_model.RnnModel 34 | filter_base: 4 35 | kernel_size: 3 36 | max_pooling: 2 37 | num_blocks: 7 38 | rnn_bidirectional: true 39 | rnn_num_layers: 1 40 | rnn_num_units: 41 | loss: 42 | import: src.model.losses.nll_loss 43 | metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 44 | optimizer: 45 | import: Adam 46 | args: 47 | lr: 0.001 48 | weight_decay: 0 49 | amsgrad: true 50 | trainer: 51 | early_stop: 10 52 | epochs: 100 # Number of training epochs 53 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 54 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 55 | n_gpu: 1 56 | num_workers: 128 57 | save_dir: experiments 58 | save_freq: 1 # save checkpoints every save_freq epochs 59 | tensorboardX: false # Enable tensorboardX visualization support 60 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 61 | -------------------------------------------------------------------------------- /src/configs/exp03-frac100-wd.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-frac100-wd 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: 
src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | # optimizer: 47 | # import: torch.optim.Adam 48 | # args: 49 | # lr: 0.1 50 | # momentum: 0.9 51 | # nesterov: true 52 | # lr_scheduler: 53 | # import: torch.optim.lr_scheduler.CyclicLR 54 | # args: 55 | # base_lr: 0.05 56 | # max_lr: 0.15 57 | # step_size_up: 500 58 | # mode: 'triangular' 59 | # optimizer: 60 | # import: torch.optim.Adam 61 | # args: 62 | # lr: 0.0001 63 | # weight_decay: 0 64 | # amsgrad: true 65 | # lr_scheduler: 66 | # import: torch.optim.lr_scheduler.ReduceLROnPlateau 67 | # args: 68 | # mode: 'min' 69 | # factor: 0.5 70 | # patience: 5 71 | # verbose: true 72 | optimizer: 73 | import: torch.optim.SGD 74 | args: 75 | lr: 0.1 76 | momentum: 0.9 77 | nesterov: true 78 | weight_decay: 0.0005 79 | lr_scheduler: 80 | import: torch.optim.lr_scheduler.CyclicLR 81 | args: 82 | base_lr: 0.05 83 | max_lr: 0.15 84 | step_size_up: 500 85 | mode: 'triangular' 86 | trainer: 87 | # early_stop: 11 88 | epochs: 40 # Number of training epochs 89 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 90 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
91 | n_gpu: 4 92 | num_workers: 128 93 | save_dir: experiments 94 | save_freq: 1 # save checkpoints every save_freq epochs 95 | tensorboardX: false # Enable tensorboardX visualization support 96 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 97 | -------------------------------------------------------------------------------- /src/configs/exp03-frac100.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-frac100 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | train_fraction: 31 | num_classes: 5 32 | segment_length: 300 # Length in seconds 33 | network: 34 | import: src.model.rnn_model.RnnModel 35 | filter_base: 4 36 | kernel_size: 3 37 | max_pooling: 2 38 | num_blocks: 7 39 | rnn_bidirectional: true 40 | rnn_num_layers: 1 41 | rnn_num_units: 1024 42 | loss: 43 | import: src.model.losses.temporal_crossentropy_loss 44 | metrics: [overall_accuracy, balanced_accuracy, kappa] 45 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 46 | # optimizer: 47 | # import: torch.optim.Adam 48 | # args: 49 | # lr: 0.1 50 | # momentum: 0.9 51 | # nesterov: true 52 | # lr_scheduler: 53 | # import: torch.optim.lr_scheduler.CyclicLR 54 | # args: 55 | # base_lr: 0.05 56 | # max_lr: 0.15 57 | # step_size_up: 500 58 | # mode: 'triangular' 59 | # optimizer: 60 | # import: torch.optim.Adam 61 | # args: 62 | # lr: 
0.0001 63 | # weight_decay: 0 64 | # amsgrad: true 65 | # lr_scheduler: 66 | # import: torch.optim.lr_scheduler.ReduceLROnPlateau 67 | # args: 68 | # mode: 'min' 69 | # factor: 0.5 70 | # patience: 5 71 | # verbose: true 72 | optimizer: 73 | import: torch.optim.SGD 74 | args: 75 | lr: 0.1 76 | momentum: 0.9 77 | nesterov: true 78 | lr_scheduler: 79 | import: torch.optim.lr_scheduler.CyclicLR 80 | args: 81 | base_lr: 0.05 82 | max_lr: 0.15 83 | step_size_up: 500 84 | mode: 'triangular' 85 | trainer: 86 | # early_stop: 11 87 | epochs: 40 # Number of training epochs 88 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 89 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 90 | n_gpu: 4 91 | num_workers: 128 92 | save_dir: experiments 93 | save_freq: 1 # save checkpoints every save_freq epochs 94 | tensorboardX: false # Enable tensorboardX visualization support 95 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 96 | -------------------------------------------------------------------------------- /src/configs/exp03-isruc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.1 64 | max_lr: 0.5 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp03-mros.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-mros 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | 
base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp03-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # 
metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp03-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in 
seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.1 64 | max_lr: 0.25 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp03-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp03-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 
0.1 64 | max_lr: 0.25 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-shhs-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-shhs-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, 
kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-shhs-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-shhs-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: 
torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: 
src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: 
./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 
62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-mros.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-mros 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-shhs-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-shhs-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 
500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-shhs-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-shhs-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 
62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-shhs-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-shhs-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable. 
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 
62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-isruc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-isruc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-shhs-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-shhs-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: 
torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-shhs-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-shhs-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 
52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-shhs-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-shhs-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, 
emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-shhs.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-shhs 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | 
args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-mros-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-mros-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | 
args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-shhs-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-shhs-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-shhs-ssc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-shhs-ssc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | 
num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # mode and metric for model performance monitoring. set 'off' to disable.
72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-shhs-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-shhs-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | - [shhs, train] 14 | # - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | - [shhs, eval] 20 | # - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | 
args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/exp04-ssc-wsc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp04-ssc-wsc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 256 8 | test: 256 9 | data: 10 | train: 11 | # - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | - [ssc, train] 15 | - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | # - [mros, train] 29 | # - [shhs, train] 30 | # - [ssc, train] 31 | # - [wsc, train] 32 | # - [mros, eval] 33 | # - [shhs, eval] 34 | # - [ssc, eval] 35 | # - [wsc, eval] 36 | data_dir: ./data/processed_oak 37 | modalities: [eeg, eog, emg] 38 | train_fraction: 500 39 | num_classes: 5 40 | segment_length: 300 # Length in seconds 41 | network: 42 | import: src.model.rnn_model.RnnModel 43 | filter_base: 4 44 | kernel_size: 3 45 | max_pooling: 2 46 | num_blocks: 7 47 | rnn_bidirectional: true 48 | rnn_num_layers: 1 49 | rnn_num_units: 1024 50 | loss: 51 | import: src.model.losses.temporal_crossentropy_loss 52 | metrics: [overall_accuracy, 
balanced_accuracy, kappa] 53 | # metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 54 | optimizer: 55 | import: torch.optim.SGD 56 | args: 57 | lr: 0.1 58 | momentum: 0.9 59 | nesterov: true 60 | lr_scheduler: 61 | import: torch.optim.lr_scheduler.CyclicLR 62 | args: 63 | base_lr: 0.05 64 | max_lr: 0.15 65 | step_size_up: 500 66 | mode: 'triangular' 67 | trainer: 68 | # early_stop: 11 69 | epochs: 50 # Number of training epochs 70 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 71 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 72 | n_gpu: 4 73 | num_workers: 128 74 | save_dir: experiments 75 | save_freq: 1 # save checkpoints every save_freq epochs 76 | tensorboardX: false # Enable tensorboardX visualization support 77 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 78 | -------------------------------------------------------------------------------- /src/configs/isruc.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: exp02-isruc 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 256 7 | eval: 32 8 | test: 32 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | # - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | num_classes: 5 31 | segment_length: 300 # Length in seconds 32 | network: 33 | import: src.model.rnn_model.RnnModel 34 | filter_base: 4 35 | kernel_size: 3 36 | max_pooling: 2 37 | num_blocks: 7 38 | 
rnn_bidirectional: true 39 | rnn_num_layers: 1 40 | rnn_num_units: 41 | loss: 42 | import: src.model.losses.nll_loss 43 | metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 44 | optimizer: 45 | import: Adam 46 | args: 47 | lr: 0.001 48 | weight_decay: 0 49 | amsgrad: true 50 | trainer: 51 | early_stop: 10 52 | epochs: 100 # Number of training epochs 53 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 54 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 55 | n_gpu: 4 56 | num_workers: 128 57 | save_dir: experiments 58 | save_freq: 1 # save checkpoints every save_freq epochs 59 | tensorboardX: false # Enable tensorboardX visualization support 60 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 61 | -------------------------------------------------------------------------------- /src/configs/test-dataset.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: test-dataset 3 | data_loader: 4 | name: dataset.MultiCohortDataset 5 | batch_size: 6 | train: 64 7 | eval: 32 8 | test: 64 9 | data: 10 | train: 11 | # - [isruc, train] 12 | - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | # - [isruc, eval] 18 | - [mros, eval] 19 | # - [shhs, eval] 20 | # - [ssc, eval] 21 | # - [wsc, eval] 22 | test: 23 | # - [isruc, test] 24 | - [mros, test] 25 | # - [shhs, test] 26 | # - [ssc, test] 27 | # - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | num_classes: 5 31 | segment_length: 300 # Length in seconds 32 | -------------------------------------------------------------------------------- /src/configs/test-rnn_model.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: test-rnn_model 3 | 
data_loader: 4 | name: dataset.MultiCohortDataset 5 | batch_size: 6 | train: 32 7 | eval: 32 8 | test: 32 9 | data: 10 | train: 11 | - [isruc, train] 12 | - [mros, train] 13 | - [shhs, train] 14 | - [wsc, train] 15 | - [isruc, eval] 16 | - [mros, eval] 17 | - [shhs, eval] 18 | - [wsc, eval] 19 | - [isruc, test] 20 | - [mros, test] 21 | - [shhs, test] 22 | - [wsc, test] 23 | eval: 24 | - [isruc, eval] 25 | - [mros, eval] 26 | - [shhs, eval] 27 | - [wsc, eval] 28 | test: 29 | - [isruc, test] 30 | - [mros, test] 31 | - [shhs, test] 32 | - [wsc, test] 33 | data_dir: ./data/processed 34 | modalities: [eeg, eog, emg] 35 | num_classes: 5 36 | segment_length: 300 # Length in seconds 37 | network: 38 | import: src.model.rnn_model.RnnModel 39 | filter_base: 4 40 | kernel_size: 3 41 | max_pooling: 2 42 | num_blocks: 7 43 | rnn_bidirectional: true 44 | rnn_num_layers: 1 45 | rnn_num_units: 46 | loss: 47 | name: nll_loss 48 | metrics: -------------------------------------------------------------------------------- /src/configs/test.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: test 3 | data_loader: 4 | import: src.data_loader.dataset.MultiCohortDataset 5 | batch_size: 6 | train: 64 7 | eval: 32 8 | test: 32 9 | data: 10 | train: 11 | - [isruc, train] 12 | # - [mros, train] 13 | # - [shhs, train] 14 | # - [ssc, train] 15 | # - [wsc, train] 16 | eval: 17 | - [isruc, eval] 18 | - [mros, eval] 19 | - [shhs, eval] 20 | - [ssc, eval] 21 | - [wsc, eval] 22 | test: 23 | - [isruc, test] 24 | - [mros, test] 25 | - [shhs, test] 26 | - [ssc, test] 27 | - [wsc, test] 28 | data_dir: ./data/processed_oak 29 | modalities: [eeg, eog, emg] 30 | num_classes: 5 31 | segment_length: 300 # Length in seconds 32 | network: 33 | import: src.model.rnn_model.RnnModel 34 | filter_base: 4 35 | kernel_size: 3 36 | max_pooling: 2 37 | num_blocks: 7 38 | rnn_bidirectional: true 39 | rnn_num_layers: 1 40 | rnn_num_units: 41 | loss: 42 | import: 
src.model.losses.nll_loss 43 | metrics: ['overall_accuracy', 'balanced_accuracy', 'kappa', 'balanced_precision', 'overall_precision', 'balanced_recall', 'overall_recall', 'balanced_f1', 'overall_f1'] 44 | optimizer: 45 | import: Adam 46 | args: 47 | lr: 0.001 48 | weight_decay: 0 49 | amsgrad: true 50 | trainer: 51 | early_stop: 10 52 | epochs: 100 # Number of training epochs 53 | log_dir: experiments/runs # Directory in which to save log files for tensorboardX visualization 54 | monitor: min val_loss # de and metric for model performance monitoring. set 'off' to disable. 55 | n_gpu: 4 56 | num_workers: 128 57 | save_dir: experiments 58 | save_freq: 1 # save checkpoints every save_freq epochs 59 | tensorboardX: false # Enable tensorboardX visualization support 60 | verbosity: 2 # 0: quiet, 1: per epoch, 2: full 61 | -------------------------------------------------------------------------------- /src/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neergaard/deep-sleep-pytorch/db5d6f59369baacabeac956201c3c6de09961945/src/data/.gitkeep -------------------------------------------------------------------------------- /src/data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neergaard/deep-sleep-pytorch/db5d6f59369baacabeac956201c3c6de09961945/src/data_loader/__init__.py -------------------------------------------------------------------------------- /src/data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from src.base import BaseDataLoader 3 | 4 | 5 | class DataLoader(BaseDataLoader): 6 | """ 7 | Something interesting. 
8 | """ 9 | 10 | def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers, training=True): 11 | trsfm = transforms.Compose([ 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.1307,), (0.3081,)) 14 | ]) 15 | self.data_dir = data_dir 16 | self.dataset = datasets.MNIST( 17 | self.data_dir, train=training, download=True, transform=trsfm) 18 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 19 | 20 | 21 | if __name__ == '__main__': 22 | 23 | print('') 24 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neergaard/deep-sleep-pytorch/db5d6f59369baacabeac956201c3c6de09961945/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch.nn import CrossEntropyLoss 3 | 4 | 5 | def nll_loss(output, target): 6 | return F.nll_loss(output, target.long()) 7 | 8 | 9 | def temporal_crossentropy_loss(output, target): 10 | return F.cross_entropy(output, target, reduction='mean') 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | import numpy as np 16 | import torch 17 | from torch.autograd import Variable 18 | 19 | outputs = np.random.randint(0, 4, size=(32, 5, 120)) 20 | outputs = Variable(torch.from_numpy(outputs).float().cuda()) 21 | 22 | from src.utils.config import process_config 23 | 24 | config = process_config('./src/configs/test-rnn_model.yaml') 25 | model = RnnModel(config).cuda() 26 | n_channels = model.num_channels 27 | length = config.data_loader.segment_length * 128 28 | x = np.random.randn(config.data_loader.batch_size.train, 1, 29 | n_channels, length) 30 | x = Variable(torch.from_numpy(x).float()).cuda() 31 | print(model) 32 | z = model(x) 33 
import torch
from sklearn import metrics

# Sleep-stage label set shared by every sklearn-based metric below.
_LABELS = [0, 1, 2, 3, 4]


def _flatten(output, target):
    """Convert model output/targets to flat numpy label vectors.

    Args:
        output: (N, C, ...) tensor of class scores; argmax over dim 1.
        target: tensor of integer class labels.

    Returns:
        (y_true, y_pred) as 1-D numpy arrays.
    """
    y_pred = output.data.cpu().numpy().argmax(1).flatten()
    y_true = target.data.cpu().numpy().flatten()
    return y_true, y_pred


def my_metric(output, target):
    """Top-1 accuracy computed in torch (no gradients)."""
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)


def my_metric2(output, target, k=3):
    """Top-k accuracy computed in torch (no gradients)."""
    with torch.no_grad():
        pred = torch.topk(output, k, dim=1)[1]
        assert pred.shape[0] == len(target)
        correct = 0
        for i in range(k):
            correct += torch.sum(pred[:, i] == target).item()
    return correct / len(target)


def overall_accuracy(output, target):
    """Plain (micro) accuracy over all epochs."""
    y_true, y_pred = _flatten(output, target)
    return metrics.accuracy_score(y_true, y_pred)


def balanced_accuracy(output, target):
    """Accuracy averaged per class (robust to class imbalance)."""
    y_true, y_pred = _flatten(output, target)
    return metrics.balanced_accuracy_score(y_true, y_pred)


def kappa(output, target):
    """Cohen's kappa over the five sleep-stage labels."""
    y_true, y_pred = _flatten(output, target)
    # Argument order preserved from the original implementation; Cohen's
    # kappa is symmetric in its two label sequences.
    return metrics.cohen_kappa_score(y_pred, y_true, labels=_LABELS)


def precision(output, target):
    """Macro-averaged precision."""
    y_true, y_pred = _flatten(output, target)
    return metrics.precision_score(y_true, y_pred, labels=_LABELS, average='macro')


def balanced_precision(output, target):
    """Support-weighted precision."""
    y_true, y_pred = _flatten(output, target)
    return metrics.precision_score(y_true, y_pred, labels=_LABELS, average='weighted')


def overall_precision(output, target):
    """Micro-averaged precision."""
    y_true, y_pred = _flatten(output, target)
    return metrics.precision_score(y_true, y_pred, labels=_LABELS, average='micro')


def recall(output, target):
    """Macro-averaged recall."""
    y_true, y_pred = _flatten(output, target)
    return metrics.recall_score(y_true, y_pred, labels=_LABELS, average='macro')


def balanced_recall(output, target):
    """Support-weighted recall."""
    y_true, y_pred = _flatten(output, target)
    return metrics.recall_score(y_true, y_pred, labels=_LABELS, average='weighted')


def overall_recall(output, target):
    """Micro-averaged recall."""
    y_true, y_pred = _flatten(output, target)
    return metrics.recall_score(y_true, y_pred, labels=_LABELS, average='micro')


def f1(output, target):
    """Macro-averaged F1."""
    y_true, y_pred = _flatten(output, target)
    return metrics.f1_score(y_true, y_pred, labels=_LABELS, average='macro')


def balanced_f1(output, target):
    """Support-weighted F1."""
    y_true, y_pred = _flatten(output, target)
    return metrics.f1_score(y_true, y_pred, labels=_LABELS, average='weighted')


def overall_f1(output, target):
    """Micro-averaged F1.

    BUGFIX: previously used average='weighted', which made overall_f1
    identical to balanced_f1; every other overall_* metric here uses
    'micro', so overall_f1 now does too.
    """
    y_true, y_pred = _flatten(output, target)
    return metrics.f1_score(y_true, y_pred, labels=_LABELS, average='micro')
import importlib
import os
import time


def get_config_from_yaml(yaml_file):
    """Parse a yaml configuration file.

    Args:
        yaml_file: path to the yaml configuration file.

    Returns:
        (config, config_dict): a DotMap namespace view and the raw dict.
    """
    # Imported lazily so this module remains importable when PyYAML/dotmap
    # are not installed; only callers of this function need them.
    import yaml
    from dotmap import DotMap

    with open(yaml_file, 'r') as config_file:
        config_dict = yaml.safe_load(config_file)

    # Convert the dictionary to an attribute-access namespace.
    config = DotMap(config_dict)

    return config, config_dict


def process_config(yaml_file):
    """Convenience wrapper around get_config_from_yaml: return only the DotMap."""
    config, _ = get_config_from_yaml(yaml_file)
    return config


def ensure_dir(path):
    """Create `path` (including parents) if it does not exist.

    BUGFIX: the previous exists()/makedirs() pair was racy — a concurrent
    worker creating the directory between the two calls raised
    FileExistsError. makedirs(exist_ok=True) is atomic with respect to
    that race and is a no-op when the directory already exists.
    """
    os.makedirs(path, exist_ok=True)


def create_instance(config_map):
    """Resolve config_map['import'] ('pkg.module.Name') to that attribute.

    Args:
        config_map: mapping with an 'import' key holding a dotted path.

    Returns:
        The resolved class/function object (not an instance).

    Raises:
        ImportError: if the module or attribute cannot be resolved.
    """
    module_name, class_name = config_map['import'].rsplit(".", 1)

    try:
        print('Importing {}.{}'.format(module_name, class_name))
        somemodule = importlib.import_module(module_name)
        cls_instance = getattr(somemodule, class_name)
    except Exception as err:
        # BUGFIX: previously printed the error and called exit(-1), which
        # discarded the traceback and made failures impossible to handle
        # or debug from calling code; re-raise with context instead.
        raise ImportError(
            "Could not import {}.{}".format(module_name, class_name)) from err

    return cls_instance
import json
import logging
import random
import time
import warnings

from tqdm import tqdm
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=UserWarning)
    from joblib import Parallel, delayed

logging.basicConfig(level=logging.INFO, format='')


class Logger:
    """Training process logger.

    Note:
        Used by BaseTrainer to save training history.
    """

    def __init__(self):
        # 1-based entry index -> logged dict for that epoch/step.
        self.entries = {}

    def add_entry(self, entry):
        """Store `entry` under the next 1-based index."""
        self.entries[len(self.entries) + 1] = entry

    def __str__(self):
        return json.dumps(self.entries, sort_keys=True, indent=4)


def text_progessbar(seq, total=None):
    """Plain-text progress reporter: yields items of `seq`, printing one
    status line per step.

    Args:
        seq: an iterator to drain.
        total: optional total number of items, shown in the status line.
    """
    step = 1
    tick = time.time()
    while True:
        # BUGFIX (PEP 479): `yield next(seq)` let StopIteration escape the
        # generator body, which Python 3.7+ turns into a RuntimeError when
        # the input iterator is exhausted; return cleanly instead.
        try:
            item = next(seq)
        except StopIteration:
            return
        time_diff = time.time() - tick
        # NOTE(review): this is seconds/iter, despite the 'iter/sec' label
        # in the printed line — left untouched to preserve output format.
        avg_speed = time_diff / step
        # BUGFIX: '%n' is not a valid printf-style conversion and raised
        # ValueError whenever `total` was supplied; use '%d'.
        total_str = 'of %d' % total if total else ''
        print('step', step, '%.2f' % time_diff,
              'avg: %.2f iter/sec' % avg_speed, total_str)
        step += 1
        yield item


# Maps a bar-type name to a factory producing an iterator wrapper.
all_bar_funcs = {
    'tqdm': lambda args: lambda x: tqdm(x, **args),
    'txt': lambda args: lambda x: text_progessbar(x, **args),
    'False': lambda args: iter,
    'None': lambda args: iter,
}


def ParallelExecutor(use_bar='tqdm', **joblib_args):
    """Build a joblib Parallel runner wrapped in a progress bar.

    Args:
        use_bar: one of the keys of `all_bar_funcs`.
        **joblib_args: forwarded to joblib.Parallel (n_jobs, backend, ...).
    """
    def aprun(bar=use_bar, **tq_args):
        def tmp(op_iter):
            if str(bar) in all_bar_funcs.keys():
                bar_func = all_bar_funcs[str(bar)](tq_args)
            else:
                raise ValueError("Value %s not supported as bar type" % bar)
            return Parallel(**joblib_args)(bar_func(op_iter))
        return tmp
    return aprun
import pickle

import numpy as np


def read_pickle(filepath: str):
    """Read contents of output pickle files.

    Usage:
        from pickle_reader import read_pickle

        y, t = read_pickle(filename)

    Args:
        filepath: str -- path to pickle file

    Returns:
        predictions: ndarray, shape (N * T, K) -- per-second class scores
        targets: ndarray, shape (N * T,) -- per-second hypnogram labels
    """
    # The pickle file holds a dict with 'targets' and 'predictions' keys.
    with open(filepath, "rb") as pkl:
        out = pickle.load(pkl)

    # Targets: (N, T) -> (N * T,)
    targets = np.asarray(out["targets"]).reshape(-1)

    # Predictions: (N, K, T) -> (N * T, K)
    predictions = np.asarray(out["predictions"])
    N, K, T = predictions.shape
    predictions = predictions.transpose(1, 0, 2).reshape(K, N * T).T

    return predictions, targets


def rolling_window(a, window, step):
    """
    Make an ndarray with a rolling window of the last dimension with a given
    step size.

    Parameters
    ----------
    a : array_like
        Array to add rolling window to
    window : int
        Size of rolling window
    step : int
        Size of steps between windows

    Returns
    -------
    Read-only view of the original array with an added trailing dimension of
    size `window`. Do not write to it: windows may share memory.

    Examples
    --------
    >>> x = np.arange(10).reshape((2, 5))
    >>> rolling_window(x, 3, 1)
    array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]],
           [[5, 6, 7], [6, 7, 8], [7, 8, 9]]])

    Calculate rolling mean of last dimension:
    >>> np.mean(rolling_window(x, 3, 1), -1)
    array([[ 1.,  2.,  3.],
           [ 6.,  7.,  8.]])

    """
    # NOTE: docstring examples previously omitted the mandatory `step`
    # argument and did not run; fixed to `rolling_window(x, 3, 1)`.
    if window < 1:
        raise ValueError("`window` must be at least 1.")
    if window > a.shape[-1]:
        raise ValueError("`window` is too long.")
    if step < 1:
        # Robustness: step == 0 previously crashed with ZeroDivisionError
        # and a negative step produced nonsense shapes.
        raise ValueError("`step` must be at least 1.")
    shape = a.shape[:-1] + ((a.shape[-1] - window) // step + 1, window)
    strides = a.strides[:-1] + (step * a.strides[-1], a.strides[-1])
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)


def segment_signal(signal, segment_length, segment_step):
    """Thin alias for rolling_window with segmentation vocabulary."""
    return rolling_window(signal, segment_length, segment_step)


def segment_and_sequence_psg(segmentation, fs, psg, hypnogram):
    """Segment each PSG channel and the hypnogram into fixed sequences.

    Args:
        segmentation: dict with 'SEQUENCE_DURATION_MIN', 'EPOCH_LENGTH_SEC'
            and per-channel 'SEGMENT_LENGTH' / 'SEGMENT_OVERLAP' entries
            (lengths in seconds).
        fs: sampling frequency in Hz.
        psg: dict mapping channel name -> 1-D signal array; must contain
            an 'eeg' key (used as the length reference below).
        hypnogram: 1-D array of per-epoch sleep-stage labels.

    Returns:
        (psg_seg, hypno_seg): segmented channels and segmented hypnogram,
        both trimmed to a whole number of sequences.
    """
    psg_seg = {key: [] for key in psg.keys()}
    # Number of epochs per sequence, e.g. 5 min / 30 s * 60 = ... epochs.
    num_segments_in_sequence = int(
        segmentation['SEQUENCE_DURATION_MIN'] / segmentation['EPOCH_LENGTH_SEC'] * 60)
    for chn in psg_seg.keys():
        segment_length_sec = segmentation['SEGMENT_LENGTH'][chn]
        segment_overlap = segmentation['SEGMENT_OVERLAP'][chn]
        segment_length = int(segment_length_sec * fs)
        segment_step = int((segment_length_sec - segment_overlap) * fs)

        psg_seg[chn] = segment_signal(psg[chn], segment_length, segment_step)

    # Segment the hypnogram into non-overlapping sequences.
    hypno_seg = segment_signal(
        hypnogram, num_segments_in_sequence, num_segments_in_sequence)

    # If the signals and hypnogram are of different lengths, we assume that
    # the start time is fixed for both, so we trim the ends to a whole
    # number of sequences (the 'eeg' channel is the length reference).
    trim_length = np.min([hypno_seg.shape[0], psg_seg['eeg'].shape[1]])
    hypno_seg = hypno_seg[:num_segments_in_sequence *
                          (trim_length // num_segments_in_sequence), :]
    psg_seg = {chn: sig[:, :num_segments_in_sequence *
                        (trim_length // num_segments_in_sequence), :] for chn, sig in psg_seg.items()}

    return psg_seg, hypno_seg


def segmentPSG(segmentation, fs, psg):
    """Segment each PSG channel without sequencing or trimming.

    Args:
        segmentation: dict with per-channel 'SEGMENT_LENGTH' and
            'SEGMENT_OVERLAP' entries (seconds).
        fs: sampling frequency in Hz.
        psg: dict mapping channel name -> 1-D signal array.

    Returns:
        dict mapping channel name -> windowed view (n_windows, window_len).
    """
    psg_seg = {key: [] for key in psg.keys()}

    for chn in psg_seg.keys():
        segmentLength_sec = segmentation['SEGMENT_LENGTH'][chn]
        segmentOverlap = segmentation['SEGMENT_OVERLAP'][chn]
        segmentLength = int(segmentLength_sec * fs)
        segmentStep = int((segmentLength_sec - segmentOverlap) * fs)

        psg_seg[chn] = segment_signal(psg[chn], segmentLength, segmentStep)

    return psg_seg
Please install the package by 'pip install tensorboardx' command or turn " \ 15 | "off the option in the 'config.json' file." 16 | logger.warning(message) 17 | self.step = 0 18 | self.mode = '' 19 | 20 | self.tb_writer_ftns = [ 21 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 22 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 23 | ] 24 | self.tag_mode_exceptions = ['add_histogram', 'add_embedding'] 25 | 26 | def set_step(self, step, mode='train'): 27 | self.mode = mode 28 | self.step = step 29 | 30 | def __getattr__(self, name): 31 | """ 32 | If visualization is configured to use: 33 | return add_data() methods of tensorboard with additional information (step, tag) added. 34 | Otherwise: 35 | return a blank function handle that does nothing 36 | """ 37 | if name in self.tb_writer_ftns: 38 | add_data = getattr(self.writer, name, None) 39 | 40 | def wrapper(tag, data, *args, **kwargs): 41 | if add_data is not None: 42 | # add mode(train/valid) tag 43 | if name not in self.tag_mode_exceptions: 44 | tag = '{}/{}'.format(self.mode, tag) 45 | add_data(tag, data, self.step, *args, **kwargs) 46 | return wrapper 47 | else: 48 | # default action for returning methods defined in this class, set_step() for instance. 
import argparse
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

import src.model.metrics as module_metric
from src.trainer.trainer import Trainer
from src.utils.config import process_config
from src.utils.factory import create_instance
from src.utils.logger import Logger


# Reproducibility: fixed seeds plus deterministic cuDNN kernels.
np.random.seed(1337)
torch.manual_seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def main(config, resume):
    """Build datasets, model, loss, metrics and optimizer from `config`
    and run training.

    Args:
        config: DotMap configuration (see src/configs/*.yaml).
        resume: optional path to a checkpoint to resume from, or None.
    """
    train_logger = Logger()

    # Setup data_loader instances
    subsets = ['train', 'eval']
    datasets = {subset: create_instance(config.data_loader)(
        config, subset=subset) for subset in subsets}
    data_loaders = {subset: DataLoader(datasets[subset],
                                       batch_size=datasets[subset].batch_size,
                                       # BUGFIX: was `subset is 'train'` —
                                       # identity comparison of strings is
                                       # implementation-dependent (and a
                                       # SyntaxWarning on Python 3.8+).
                                       shuffle=(subset == 'train'),
                                       num_workers=config.trainer.num_workers,
                                       drop_last=True,
                                       pin_memory=True) for subset in subsets}

    # build model architecture
    model = create_instance(config.network)(config)
    print(model)
    wandb.watch(model)

    # get function handles of loss and metrics
    loss = create_instance(config.loss)
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every lines containing
    # lr_scheduler for disabling scheduler
    # NOTE(review): create_instance returns the class itself, not an
    # instance — presumably Trainer instantiates these with the model
    # parameters and config args; confirm against src/trainer/trainer.py.
    optimizer = create_instance(config.optimizer)
    lr_scheduler = create_instance(config.lr_scheduler)

    trainer = Trainer(model, loss, metrics, optimizer,
                      resume=resume,
                      config=config,
                      data_loader=data_loaders['train'],
                      valid_data_loader=data_loaders['eval'],
                      lr_scheduler=lr_scheduler,
                      train_logger=train_logger)

    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DeepSleep')
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='config file path (default: None)')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('-d', '--device', default=None, type=str,
                        help='indices of GPUs to enable (default: all)')
    args = parser.parse_args()

    # DEBUGGING overrides — keep commented out so the CLI flags work.
    # BUGFIX: args.config was unconditionally overwritten here, which
    # silently ignored any '-c' flag passed on the command line.
    # args.config = 'src/configs/exp03-frac100.yaml'
    # args.resume = 'experiments/exp01-hu2048/0502_122808/checkpoint-epoch39.pth'
    # args.device = '0'

    if args.config:
        # load config file
        config = process_config(args.config)

        # setting path to save trained models and log files
        path = os.path.join(config.trainer.save_dir, config.exp.name)

    elif args.resume:
        # load config from checkpoint if new config file is not given.
        # Use '--config' and '--resume' together to fine-tune trained model
        # with changed configurations.
        config = torch.load(args.resume)['config']

    else:
        raise AssertionError(
            "Configuration file need to be specified. Add '-c config.yaml', for example.")

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    wandb.init(project='deep-sleep')
    wandb.config.update(config.toDict())

    main(config, args.resume)

    wandb.run.summary.update()