├── .gitignore ├── README.md ├── datasets ├── __init__.py └── ssc_wsc.py ├── models ├── __init__.py └── utime.py ├── train.py └── utils ├── __init__.py ├── arg_utils.py ├── dataset_utils.py ├── evaluate_performance.py ├── h5_utils.py ├── logger_callback_utils.py ├── losses.py ├── model_utils.py └── parallel_bar.py /.gitignore: -------------------------------------------------------------------------------- 1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig 2 | 3 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,jupyternotebooks,python 4 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,jupyternotebooks,python 5 | 6 | ### JupyterNotebooks ### 7 | # gitignore template for Jupyter Notebooks 8 | # website: http://jupyter.org/ 9 | 10 | .ipynb_checkpoints 11 | */.ipynb_checkpoints/* 12 | 13 | # IPython 14 | profile_default/ 15 | ipython_config.py 16 | 17 | # Remove previous ipynb_checkpoints 18 | # git rm -r .ipynb_checkpoints/ 19 | 20 | ### Linux ### 21 | *~ 22 | 23 | # temporary files which can be created if a process still has a handle open of a deleted file 24 | .fuse_hidden* 25 | 26 | # KDE directory preferences 27 | .directory 28 | 29 | # Linux trash folder which might appear on any partition or disk 30 | .Trash-* 31 | 32 | # .nfs files are created when an open file is removed but is still being accessed 33 | .nfs* 34 | 35 | ### Python ### 36 | # Byte-compiled / optimized / DLL files 37 | __pycache__/ 38 | *.py[cod] 39 | *$py.class 40 | 41 | # C extensions 42 | *.so 43 | 44 | # Distribution / packaging 45 | .Python 46 | build/ 47 | develop-eggs/ 48 | dist/ 49 | downloads/ 50 | eggs/ 51 | .eggs/ 52 | lib/ 53 | lib64/ 54 | parts/ 55 | sdist/ 56 | var/ 57 | wheels/ 58 | pip-wheel-metadata/ 59 | share/python-wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a 
python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .nox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | *.py,cover 86 | .hypothesis/ 87 | .pytest_cache/ 88 | pytestdebug.log 89 | 90 | # Translations 91 | *.mo 92 | *.pot 93 | 94 | # Django stuff: 95 | *.log 96 | local_settings.py 97 | db.sqlite3 98 | db.sqlite3-journal 99 | 100 | # Flask stuff: 101 | instance/ 102 | .webassets-cache 103 | 104 | # Scrapy stuff: 105 | .scrapy 106 | 107 | # Sphinx documentation 108 | docs/_build/ 109 | doc/_build/ 110 | 111 | # PyBuilder 112 | target/ 113 | 114 | # Jupyter Notebook 115 | 116 | # IPython 117 | 118 | # pyenv 119 | .python-version 120 | 121 | # pipenv 122 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 123 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 124 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 125 | # install all needed dependencies. 126 | #Pipfile.lock 127 | 128 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | pythonenv* 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # profiling data 170 | .prof 171 | 172 | ### VisualStudioCode ### 173 | .vscode/* 174 | !.vscode/tasks.json 175 | !.vscode/launch.json 176 | *.code-workspace 177 | 178 | ### VisualStudioCode Patch ### 179 | # Ignore all local history of files 180 | .history 181 | .ionide 182 | 183 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,jupyternotebooks,python 184 | 185 | # Custom rules (everything added below won't be overridden by 'Generate .gitignore File' if you use 'Update' option) 186 | checkpoints 187 | data 188 | experiments 189 | .vscode 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # utime-pytorch 2 | A PyTorch/PyTorch Lightning implementation of the [U-Time model](https://arxiv.org/abs/1910.11162). 3 | 4 | This repository is still under development; please see the original repository by the authors for the original working implementation [here](https://github.com/perslev/U-Time).
5 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ssc_wsc import SscWscDataModule 2 | 3 | available_datasets = {"ssc-wsc": SscWscDataModule} 4 | -------------------------------------------------------------------------------- /datasets/ssc_wsc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import warnings 5 | from itertools import compress 6 | from pprint import pprint 7 | 8 | import numpy as np 9 | import torch 10 | import pytorch_lightning as pl 11 | from h5py import File 12 | from joblib import delayed 13 | from joblib import Memory 14 | from joblib import Parallel 15 | from sklearn import preprocessing 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Dataset 18 | from torch.utils.data import Subset 19 | from tqdm import tqdm 20 | 21 | import utils 22 | 23 | # try: 24 | # from utils.parallel_bar import ParallelExecutor 25 | # from utils.h5_utils import load_h5_data 26 | # except ImportError: 27 | # from utils.h5_utils import load_h5_data 28 | # from utils.parallel_bar import ParallelExecutor 29 | 30 | warnings.filterwarnings("ignore", category=UserWarning, module="joblib") 31 | SCALERS = {"robust": preprocessing.RobustScaler, "standard": preprocessing.StandardScaler} 32 | 33 | 34 | def get_class_sequence_idx(hypnogram, selected_sequences): 35 | d = { 36 | "w": [idx for idx, hyp in enumerate(hypnogram) if (hyp == 0).any() and idx in selected_sequences], 37 | "n1": [idx for idx, hyp in enumerate(hypnogram) if (hyp == 1).any() and idx in selected_sequences], 38 | "n2": [idx for idx, hyp in enumerate(hypnogram) if (hyp == 2).any() and idx in selected_sequences], 39 | "n3": [idx for idx, hyp in enumerate(hypnogram) if (hyp == 3).any() and idx in selected_sequences], 40 | "r": [idx for idx, hyp in 
enumerate(hypnogram) if (hyp == 4).any() and idx in selected_sequences], 41 | } 42 | return d 43 | 44 | 45 | def get_unknown_stage(onehot_hypnogram): 46 | return onehot_hypnogram.sum(axis=1) == 0 47 | 48 | 49 | def get_stable_stage(hypnogram, stage, adjustment=30): 50 | """ 51 | Args: 52 | hypnogram (array_like): hypnogram with sleep stage labels 53 | stage (int): sleep stage label ({'W': 0, 'N1': 1, 'N2': 2, 'N3': 3, 'R': 4}) 54 | adjusted (int, optional): Controls the amount of bracketing surrounding a period of stable sleep. 55 | E.g. if adjustment=30, each period of stable sleep needs to be bracketed by 30 s. 56 | Returns: 57 | stable_periods: a list of range objects where each range describes a period of stable sleep stage. 58 | """ 59 | from itertools import groupby 60 | from operator import itemgetter 61 | 62 | list_of_periods = [] 63 | for k, g in groupby(enumerate(np.where(hypnogram == stage)[0]), lambda x: x[0] - x[1]): 64 | list_of_periods.append(list(map(itemgetter(1), g))) 65 | stable_periods = [range(period[0] + adjustment, period[-1] + 1 - adjustment) for period in list_of_periods] 66 | # Some periods are empty and need to be removed 67 | stable_periods = list(filter(lambda x: list(x), stable_periods)) 68 | 69 | return stable_periods 70 | 71 | 72 | def get_stable_sleep_periods(hypnogram, adjustment=30): 73 | """Get periods of stable sleep uninterrupted by transitions 74 | 75 | Args: 76 | hypnogram (array-like): hypnogram vector or array with sleep stage labels 77 | adjustment (int): parameter controlling the amount of shift when selecting periods of stable sleep. E.g. 78 | if adjustment = 30, each period of stable sleep needs to be bracketed by 30 s of the same sleep stage. 
79 | """ 80 | hypnogram_shape = hypnogram.shape 81 | hypnogram = hypnogram.reshape(np.prod(hypnogram_shape)) 82 | stable_periods = [] 83 | stable_periods_bool = np.full(np.prod(hypnogram_shape), False) 84 | for stage in [0, 1, 2, 3, 4]: 85 | stable_periods.append(get_stable_stage(hypnogram, stage, adjustment)) 86 | for period in stable_periods[-1]: 87 | stable_periods_bool[period] = True 88 | stable_periods_bool = stable_periods_bool.reshape(hypnogram_shape) 89 | 90 | return stable_periods_bool, stable_periods 91 | 92 | 93 | def initialize_record(filename, scaling=None, overlap=True, adjustment=30): 94 | 95 | if scaling in SCALERS.keys(): 96 | scaler = SCALERS[scaling]() 97 | else: 98 | scaler = None 99 | 100 | with File(filename, "r") as h5: 101 | # if "A2081_5 194244.h5" in filename: 102 | # print("Hej") 103 | N, C, T = h5["M"].shape 104 | hypnogram = h5["L"][:, :, ::30] 105 | hyp_shape = hypnogram.shape 106 | sequences_in_file = N 107 | 108 | if scaler: 109 | scaler.fit(h5["M"][:].transpose(1, 0, 2).reshape((C, N * T)).T) 110 | 111 | # Remember that the output array from the H5 has 50 % overlap between segments. 
112 | # Use the following to split into even and odd 113 | if overlap: 114 | hyp_even = hypnogram[0::2] 115 | hyp_odd = hypnogram[1::2] 116 | if adjustment > 0: 117 | stable_sleep = np.full([v for idx, v in enumerate(hyp_shape) if idx != 1], False) 118 | stable_sleep[0::2] = get_stable_sleep_periods(hyp_even.argmax(axis=1), adjustment)[0] 119 | stable_sleep[1::2] = get_stable_sleep_periods(hyp_odd.argmax(axis=1), adjustment)[0] 120 | else: 121 | stable_sleep = np.full([v for idx, v in enumerate(hyp_shape) if idx != 1], True) 122 | else: 123 | if adjustment > 0: 124 | stable_sleep = get_stable_sleep_periods(h5["L"][:].argmax(axis=1), adjustment)[0] 125 | else: 126 | stable_sleep = np.full([v for idx, v in enumerate(hyp_shape) if idx != 1], True) 127 | 128 | # Remove unknown stage 129 | unknown_stage = get_unknown_stage(hypnogram) 130 | stable_sleep[unknown_stage] = False 131 | 132 | # Get bin counts 133 | if overlap: 134 | # hyp = h5["L"][::2].argmax(axis=1)[~get_unknown_stage(h5["L"][::2])][::30] 135 | hyp = hyp_even.argmax(axis=1)[~unknown_stage[::2] & stable_sleep[::2]] 136 | else: 137 | # hyp = h5["L"][:].argmax(axis=1)[~get_unknown_stage(h5["L"][:])][::30] 138 | hyp = hypnogram.argmax(axis=1)[~unknown_stage & stable_sleep] 139 | bin_counts = np.bincount(hyp, minlength=C) 140 | 141 | return hypnogram.argmax(1), sequences_in_file, scaler, stable_sleep, bin_counts 142 | 143 | 144 | def load_psg_h5_data(filename, scaling=None): 145 | scaler = None 146 | 147 | if scaling: 148 | scaler = SCALERS[scaling]() 149 | 150 | with File(filename, "r") as h5: 151 | N, C, T = h5["M"].shape 152 | sequences_in_file = N 153 | 154 | if scaling: 155 | scaler.fit(h5["M"][:].transpose(1, 0, 2).reshape((C, N * T)).T) 156 | 157 | return sequences_in_file, scaler 158 | # X = h5['M'][:].astype('float32') 159 | # y = h5['L'][:].astype('float32') 160 | 161 | # sequences_in_file = X.shape[0] 162 | 163 | # return X, y, sequences_in_file 164 | 165 | 166 | class SscWscPsgDataset(Dataset): 167 | 
def __init__( 168 | self, 169 | data_dir=None, 170 | n_jobs=-1, 171 | scaling=None, 172 | adjustment=30, 173 | n_records=None, 174 | overlap=True, 175 | beta=0.999, 176 | cv=None, 177 | cv_idx=None, 178 | eval_ratio=None, 179 | ): 180 | super().__init__() 181 | self.data_dir = data_dir 182 | self.n_jobs = n_jobs 183 | self.scaling = scaling 184 | self.adjustment = adjustment 185 | self.n_records = n_records 186 | self.overlap = overlap 187 | self.beta = beta 188 | self.cv = cv 189 | self.cv_idx = cv_idx 190 | self.eval_ratio = eval_ratio 191 | 192 | self.records = sorted(os.listdir(self.data_dir))[: self.n_records] 193 | # self.data = {r: [] for r in self.records} 194 | self.index_to_record = [] 195 | self.index_to_record_class = {"w": [], "n1": [], "n2": [], "n3": [], "r": []} 196 | # self.record_to_index = [] 197 | self.record_indices = {r: None for r in self.records} 198 | self.scalers = {r: None for r in self.records} 199 | self.stable_sleep = {r: None for r in self.records} 200 | self.record_class_indices = {r: None for r in self.records} 201 | # self.batch_indices = [] 202 | # self.current_record_idx = -1 203 | # self.current_record = None 204 | # self.loaded_record = None 205 | # self.current_position = None 206 | # data = load_psg_h5_data(os.path.join(self.data_dir, self.records[0])) 207 | self.cache_dir = "data/.cache" 208 | memory = Memory(self.cache_dir, mmap_mode="r", verbose=0) 209 | get_data = memory.cache(initialize_record) 210 | 211 | # Get information about the data 212 | print(f"Loading mmap data using {n_jobs} workers:") 213 | data = utils.ParallelExecutor(n_jobs=n_jobs, prefer="threads")(total=len(self.records))( 214 | delayed(get_data)( 215 | filename=os.path.join(self.data_dir, record), 216 | scaling=self.scaling, 217 | adjustment=self.adjustment, 218 | overlap=self.overlap, 219 | ) 220 | for record in self.records 221 | ) 222 | # for record, d in zip(tqdm(self.records, desc='Processing'), data): 223 | # seqs_in_file = d[2] 224 | # 
self.data[record] = {'data': d[0], 'target': d[1]} 225 | self.n_classes = 5 226 | cum_class_counts = np.zeros(self.n_classes, dtype=np.int64) 227 | for record, (hypnogram, sequences_in_file, scaler, stable_sleep, class_counts) in zip( 228 | tqdm(self.records, desc="Processing"), data 229 | ): 230 | # Some sequences are all unstable sleep, which interferes with the loss calculations. 231 | # This selects sequences where at least one epoch is sleep. 232 | select_sequences = np.where(stable_sleep.squeeze().any(axis=1))[0] 233 | self.record_indices[record] = select_sequences # np.arange(sequences_in_file) 234 | self.record_class_indices[record] = get_class_sequence_idx(hypnogram, select_sequences) 235 | self.index_to_record.extend( 236 | [{"record": record, "idx": x} for x in select_sequences] 237 | ) # range(sequences_in_file)]) 238 | for c in self.index_to_record_class.keys(): 239 | self.index_to_record_class[c].extend( 240 | [ 241 | { 242 | "idx": [ 243 | idx 244 | for idx, i2r in enumerate(self.index_to_record) 245 | if i2r["idx"] == x and record == i2r["record"] 246 | ][0], 247 | "record": record, 248 | "record_idx": x, 249 | } 250 | for x in self.record_class_indices[record][c] 251 | ] 252 | ) 253 | self.scalers[record] = scaler 254 | self.stable_sleep[record] = stable_sleep 255 | cum_class_counts += class_counts 256 | 257 | # Define the class-balanced weights. 
We normalize the class counts to the lowest value as the numerator 258 | # otherwise will dominate the expression 259 | # (see https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf) 260 | self.cb_weights_norm = (1 - self.beta) / (1 - self.beta ** (cum_class_counts / cum_class_counts.min())) 261 | self.effective_samples = 1 / self.cb_weights_norm 262 | self.cb_weights = self.cb_weights_norm * self.n_classes / self.cb_weights_norm.sum() 263 | print("") 264 | print(f"Class counts: {cum_class_counts}") 265 | print(f"Beta: {self.beta}") 266 | print(f"CB weights norm: {self.cb_weights_norm}") 267 | print(f"Effective samples: {self.effective_samples}") 268 | print(f"CB weights: {self.cb_weights}") 269 | print("") 270 | print("Finished loading data") 271 | 272 | def shuffle_records(self): 273 | random.shuffle(self.records) 274 | 275 | def split_data(self): 276 | n_records = len(self.records) 277 | self.shuffle_records() 278 | # if self.cv is None: 279 | n_eval = int(n_records * self.eval_ratio) 280 | n_train = n_records - n_eval 281 | train_idx = np.arange(n_eval, n_records) 282 | eval_idx = np.arange(0, n_eval) 283 | # train_data = SscWscPsgSubset(self, np.arange(n_eval, n_records), name="Train") 284 | # eval_data = SscWscPsgSubset(self, np.arange(0, n_eval), name="Validation") 285 | # else: 286 | if self.cv: 287 | # from sklearn.model_selection import KFold, StratifiedKFold 288 | 289 | # kf = KFold(n_splits=np.abs(self.cv)) 290 | ssc_idx = ["SSC" in s for s in np.array(self.records)[train_idx]] 291 | kf = sklearn.model_selection.StratifiedKFold(n_splits=np.abs(self.cv)) 292 | if self.cv > 0: 293 | # train_idx, eval_idx = list(kf.split(np.arange(n_records)))[self.cv_idx] 294 | _, train_idx = list(kf.split(train_idx, ssc_idx))[self.cv_idx] 295 | else: 296 | # eval_idx, train_idx = list(kf.split(np.arange(n_records)))[self.cv_idx] 297 | # train_idx, _ = 
list(kf.split(np.arange(n_records)))[self.cv_idx] 298 | train_idx, _ = list(kf.split(train_idx, ssc_idx))[self.cv_idx] 299 | print("\n") 300 | print(f"Running {np.abs(self.cv)}-fold cross-validation procedure.") 301 | print(f"Current split: {self.cv_idx}") 302 | print(f"Eval record indices: {eval_idx}") 303 | print(f"Train record indices: {train_idx}") 304 | print(f"Number of train/eval records: {len(train_idx)}/{len(eval_idx)}") 305 | print("\n") 306 | train_data = SscWscPsgSubset(self, train_idx, name="Train") 307 | eval_data = SscWscPsgSubset(self, eval_idx, name="Validation") 308 | 309 | return train_data, eval_data 310 | 311 | def __len__(self): 312 | # if isinstance(self.index_to_record, dict): 313 | # return sum([len(v) for v in self.index_to_record.values()]) 314 | # else: 315 | return len(self.index_to_record) 316 | 317 | def __getitem__(self, idx): 318 | 319 | try: 320 | # Grab data 321 | current_record = self.index_to_record[idx]["record"] 322 | current_sequence = self.index_to_record[idx]["idx"] 323 | scaler = self.scalers[current_record] 324 | stable_sleep = np.array(self.stable_sleep[current_record][current_sequence]).squeeze() 325 | 326 | # Grab data 327 | with File(os.path.join(self.data_dir, current_record), "r") as f: 328 | x = f["M"][current_sequence].astype("float32") 329 | t = f["L"][current_sequence, :, ::30].astype("uint8").squeeze() 330 | # x = self.data[current_record]['data'][current_sequence] 331 | # t = self.data[current_record]['target'][current_sequence] 332 | 333 | except IndexError: 334 | print("Bug") 335 | 336 | if np.isnan(x).any(): 337 | print("NaNs detected!") 338 | 339 | if scaler: 340 | x = scaler.transform(x.T).T # (n_channels, n_samples) 341 | 342 | return x, t, current_record, current_sequence, stable_sleep 343 | 344 | def __str__(self): 345 | s = f""" 346 | ====================================== 347 | SSC-WSC PSG Dataset Dataset 348 | -------------------------------------- 349 | Data directory: {self.data_dir} 350 | Number 
of records: {len(self.records)} 351 | ====================================== 352 | """ 353 | 354 | return s 355 | 356 | 357 | # def collate_fn(batch): 358 | # 359 | # x, t, w = ( 360 | # np.stack([b[0] for b in batch]), 361 | # np.stack([b[1] for b in batch]), 362 | # np.stack([b[2] for b in batch]) 363 | # ) 364 | # 365 | # return torch.FloatTensor(x), torch.IntTensor(t), torch.FloatTensor(w) 366 | 367 | 368 | def collate_fn(batch): 369 | 370 | X, y = map(torch.FloatTensor, zip(*batch)) 371 | 372 | return X, y 373 | # return torch.FloatTensor(X), torch.FloatTensor(y), torch.FloatTensor(w) 374 | 375 | 376 | class SscWscPsgSubset(Dataset): 377 | def __init__(self, dataset, record_indices, name="Train"): 378 | self.dataset = dataset 379 | self.record_indices = record_indices 380 | self.name = name 381 | self.records = [self.dataset.records[idx] for idx in self.record_indices] 382 | if self.name.lower() == "train": 383 | self.sequence_indices = self._get_subset_class_indices() 384 | else: 385 | self.sequence_indices = self._get_subset_indices() 386 | # self.sequence_indices = [idx for idx, v in enumerate(self.dataset.index_to_record) for r in self.records if v['record'] == r]# [idx for idx, v in enumerate(self.dataset.index_to_record) for r in self.records if v['record'] == r] 387 | # print("BAAAD BOIII") 388 | 389 | def _get_subset_class_indices(self): 390 | out = {k: None for k in self.dataset.index_to_record_class.keys()} 391 | for c in out.keys(): 392 | t = list(map(lambda x: x["record"] in self.records, self.dataset.index_to_record_class[c])) 393 | out[c] = list(compress(range(len(t)), t)) 394 | return out 395 | 396 | def _get_subset_indices(self): 397 | t = list(map(lambda x: x["record"] in self.records, self.dataset.index_to_record)) 398 | return list(compress(range(len(t)), t)) 399 | 400 | def __getitem__(self, idx): 401 | if isinstance(self.sequence_indices, dict): 402 | class_choice = np.random.choice(list(self.sequence_indices.keys())) 403 | sequence_choice 
= np.random.choice(self.sequence_indices[class_choice]) 404 | return self.dataset[self.dataset.index_to_record_class[class_choice][sequence_choice]["idx"]] 405 | else: 406 | return self.dataset[self.sequence_indices[idx]] 407 | 408 | def __len__(self): 409 | if isinstance(self.sequence_indices, dict): 410 | return sum([len(v) for v in self.sequence_indices.values()]) 411 | else: 412 | return len(self.sequence_indices) 413 | 414 | def __str__(self): 415 | s = f""" 416 | ====================================== 417 | SSC-WSC PSG Dataset - {self.name} partition 418 | -------------------------------------- 419 | Data directory: {self.dataset.data_dir} 420 | Number of records: {len(self.record_indices)} 421 | First ten records: {self.records[:10]} 422 | ====================================== 423 | """ 424 | 425 | return s 426 | 427 | 428 | class SscWscDataModule(pl.LightningDataModule): 429 | def __init__( 430 | self, 431 | batch_size, 432 | cv=None, 433 | cv_idx=None, 434 | data_dir=None, 435 | eval_ratio=0.1, 436 | n_workers=0, 437 | n_jobs=-1, 438 | n_records=None, 439 | scaling="robust", 440 | adjustment=None, 441 | **kwargs, 442 | ): 443 | super().__init__() 444 | self.adjustment = adjustment 445 | self.data_dir = data_dir 446 | self.batch_size = batch_size 447 | self.cv = cv 448 | self.cv_idx = cv_idx 449 | self.eval_ratio = eval_ratio 450 | self.n_jobs = n_jobs 451 | self.n_records = n_records 452 | self.n_workers = n_workers 453 | self.scaling = scaling 454 | self.data = {"train": os.path.join(data_dir, "train"), "test": os.path.join(data_dir, "test")} 455 | self.dataset_params = dict( 456 | # data_dir=self.data_dir, 457 | cv=self.cv, 458 | cv_idx=self.cv_idx, 459 | eval_ratio=self.eval_ratio, 460 | n_jobs=self.n_jobs, 461 | n_records=self.n_records, 462 | scaling=self.scaling, 463 | adjustment=self.adjustment, 464 | ) 465 | 466 | def setup(self, stage="fit"): 467 | if stage == "fit": 468 | dataset = SscWscPsgDataset(data_dir=self.data["train"], 
**self.dataset_params) 469 | self.train, self.eval = dataset.split_data() 470 | elif stage == "test": 471 | self.test = SscWscPsgDataset(data_dir=self.data["test"], overlap=False, **self.dataset_params) 472 | 473 | def train_dataloader(self): 474 | """Return training dataloader.""" 475 | return torch.utils.data.DataLoader( 476 | self.train, 477 | batch_size=self.batch_size, 478 | shuffle=True, 479 | num_workers=self.n_workers, 480 | pin_memory=True, 481 | # drop_last=True, 482 | ) 483 | 484 | def val_dataloader(self): 485 | """Return validation dataloader.""" 486 | return torch.utils.data.DataLoader( 487 | self.eval, 488 | batch_size=self.batch_size, 489 | shuffle=False, 490 | num_workers=self.n_workers, 491 | pin_memory=True, 492 | # drop_last=True, 493 | ) 494 | 495 | def test_dataloader(self): 496 | """Return test dataloader.""" 497 | return torch.utils.data.DataLoader( 498 | self.test, batch_size=self.batch_size, shuffle=False, num_workers=self.n_workers, pin_memory=True, 499 | ) 500 | 501 | @staticmethod 502 | def add_dataset_specific_args(parent_parser): 503 | from argparse import ArgumentParser 504 | 505 | # DATASET specific 506 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 507 | dataset_group = parser.add_argument_group("dataset") 508 | dataset_group.add_argument("--data_dir", default="data/ssc_wsc/5min", type=str) 509 | dataset_group.add_argument("--eval_ratio", default=0.1, type=float) 510 | dataset_group.add_argument("--n_jobs", default=-1, type=int) 511 | dataset_group.add_argument("--n_records", default=None, type=int) 512 | dataset_group.add_argument("--scaling", default="robust", type=str) 513 | dataset_group.add_argument("--adjustment", default=0, type=int) 514 | dataset_group.add_argument("--cv", default=None, type=int) 515 | dataset_group.add_argument("--cv_idx", default=None, type=int) 516 | 517 | # DATALOADER specific 518 | dataloader_group = parser.add_argument_group("dataloader") 519 | 
dataloader_group.add_argument("--batch_size", default=12, type=int) 520 | dataloader_group.add_argument("--n_workers", default=20, type=int) 521 | 522 | return parser 523 | 524 | 525 | if __name__ == "__main__": 526 | 527 | np.random.seed(42) 528 | random.seed(42) 529 | 530 | # dataset_params = dict(data_dir="./data/raw/individual_encodings", n_jobs=1, scaling="robust", n_records=10) 531 | # dataset_params = dict(data_dir="./data/ssc_wsc/raw/5min", n_jobs=1, scaling="robust", n_records=10,) 532 | # dataset = SscWscPsgDataset(**dataset_params) 533 | dm_params = dict( 534 | batch_size=32, 535 | n_workers=0, 536 | data_dir="./data/ssc_wsc/5min", 537 | eval_ratio=0.1, 538 | n_records=None, 539 | scaling="robust", 540 | adjustment=15, 541 | n_jobs=-1, 542 | ) 543 | dm = SscWscDataModule(**dm_params) 544 | dm.setup("fit") 545 | print(dm.train) 546 | pbar = tqdm(dm.train_dataloader()) 547 | for idx, batch in enumerate(pbar): 548 | if idx == 0: 549 | print(batch) 550 | # pbar = tqdm(DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0, pin_memory=True)) 551 | # for idx, (x, t) in enumerate(pbar): 552 | # if idx == 0: 553 | # print(x.shape) 554 | # print(eval_data) 555 | # pbar = tqdm(DataLoader(eval_data, batch_size=32, shuffle=True, num_workers=20, pin_memory=True)) 556 | # for x, t in pbar: 557 | # pass 558 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .utime import UTimeModel 2 | 3 | available_models = {"utime": UTimeModel} 4 | -------------------------------------------------------------------------------- /models/utime.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | # import matplotlib.pyplot as plt 4 | import torch 5 | import torch.distributed as dist 6 | import torch.nn as nn 7 | # import wandb 8 | #from pytorch_lightning 
import EvalResult, TrainResult 9 | from pytorch_lightning import LightningModule 10 | from pytorch_lightning.metrics import Accuracy 11 | from sklearn import metrics 12 | from tqdm import tqdm 13 | 14 | import utils 15 | # from utils.plotting import plot_segment 16 | 17 | 18 | class ConvBNReLU(nn.Module): 19 | def __init__(self, in_channels=5, out_channels=5, kernel_size=3, dilation=1, activation="relu"): 20 | super().__init__() 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.kernel_size = kernel_size 24 | self.dilation = dilation 25 | self.activation = activation 26 | self.padding = (self.kernel_size + (self.kernel_size - 1) * (self.dilation - 1) - 1) // 2 27 | 28 | self.layers = nn.Sequential( 29 | nn.ConstantPad1d(padding=(self.padding, self.padding), value=0), 30 | nn.Conv1d( 31 | in_channels=self.in_channels, 32 | out_channels=self.out_channels, 33 | kernel_size=self.kernel_size, 34 | dilation=self.dilation, 35 | bias=True, 36 | ), 37 | nn.ReLU(), 38 | nn.BatchNorm1d(self.out_channels), 39 | ) 40 | nn.init.xavier_uniform_(self.layers[1].weight) 41 | nn.init.zeros_(self.layers[1].bias) 42 | 43 | def forward(self, x): 44 | return self.layers(x) 45 | 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, filters=[16, 32, 64, 128], in_channels=5, maxpool_kernels=[10, 8, 6, 4], kernel_size=5, dilation=2): 49 | super().__init__() 50 | self.filters = filters 51 | self.in_channels = in_channels 52 | self.maxpool_kernels = maxpool_kernels 53 | self.kernel_size = kernel_size 54 | self.dilation = dilation 55 | assert len(self.filters) == len( 56 | self.maxpool_kernels 57 | ), f"Number of filters ({len(self.filters)}) does not equal number of supplied maxpool kernels ({len(self.maxpool_kernels)})!" 
58 | 59 | self.depth = len(self.filters) 60 | 61 | # fmt: off 62 | self.blocks = nn.ModuleList([nn.Sequential( 63 | ConvBNReLU( 64 | in_channels=self.in_channels if k == 0 else self.filters[k - 1], 65 | out_channels=self.filters[k], 66 | kernel_size=self.kernel_size, 67 | dilation=self.dilation, 68 | activation="relu", 69 | ), 70 | ConvBNReLU( 71 | in_channels=self.filters[k], 72 | out_channels=self.filters[k], 73 | kernel_size=self.kernel_size, 74 | dilation=self.dilation, 75 | activation="relu", 76 | ), 77 | ) for k in range(self.depth)]) 78 | # fmt: on 79 | 80 | self.maxpools = nn.ModuleList([nn.MaxPool1d(self.maxpool_kernels[k]) for k in range(self.depth)]) 81 | 82 | self.bottom = nn.Sequential( 83 | ConvBNReLU( 84 | in_channels=self.filters[-1], 85 | out_channels=self.filters[-1] * 2, 86 | kernel_size=self.kernel_size, 87 | ), 88 | ConvBNReLU( 89 | in_channels=self.filters[-1] * 2, 90 | out_channels=self.filters[-1] * 2, 91 | kernel_size=self.kernel_size 92 | ), 93 | ) 94 | 95 | def forward(self, x): 96 | shortcuts = [] 97 | for layer, maxpool in zip(self.blocks, self.maxpools): 98 | z = layer(x) 99 | shortcuts.append(z) 100 | x = maxpool(z) 101 | 102 | # Bottom part 103 | encoded = self.bottom(x) 104 | 105 | return encoded, shortcuts 106 | 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, filters=[128, 64, 32, 16], upsample_kernels=[4, 6, 8, 10], in_channels=256, out_channels=5, kernel_size=5): 110 | super().__init__() 111 | self.filters = filters 112 | self.upsample_kernels = upsample_kernels 113 | self.in_channels = in_channels 114 | self.kernel_size = kernel_size 115 | self.out_channels = out_channels 116 | assert len(self.filters) == len( 117 | self.upsample_kernels 118 | ), f"Number of filters ({len(self.filters)}) does not equal number of supplied upsample kernels ({len(self.upsample_kernels)})!" 
119 | self.depth = len(self.filters) 120 | 121 | # fmt: off 122 | self.upsamples = nn.ModuleList([nn.Sequential( 123 | nn.Upsample(scale_factor=self.upsample_kernels[k]), 124 | ConvBNReLU( 125 | in_channels=self.in_channels if k == 0 else self.filters[k - 1], 126 | out_channels=self.filters[k], 127 | kernel_size=self.kernel_size, 128 | activation='relu', 129 | ) 130 | ) for k in range(self.depth)]) 131 | 132 | self.blocks = nn.ModuleList([nn.Sequential( 133 | ConvBNReLU( 134 | in_channels=self.in_channels if k == 0 else self.filters[k - 1], 135 | out_channels=self.filters[k], 136 | kernel_size=self.kernel_size, 137 | ), 138 | ConvBNReLU( 139 | in_channels=self.filters[k], 140 | out_channels=self.filters[k], 141 | kernel_size=self.kernel_size, 142 | ), 143 | ) for k in range(self.depth)]) 144 | # fmt: off 145 | 146 | def forward(self, z, shortcuts): 147 | 148 | for upsample, block, shortcut in zip(self.upsamples, self.blocks, shortcuts[::-1]): 149 | z = upsample(z) 150 | z = torch.cat([shortcut, z], dim=1) 151 | z = block(z) 152 | 153 | return z 154 | 155 | 156 | class SegmentClassifier(nn.Module): 157 | def __init__(self, sampling_frequency=128, num_classes=5, epoch_length=30): 158 | super().__init__() 159 | # self.sampling_frequency = sampling_frequency 160 | self.num_classes = num_classes 161 | # self.epoch_length = epoch_length 162 | self.layers = nn.Sequential( 163 | # nn.AvgPool2d(kernel_size=(1, self.epoch_length * self.sampling_frequency)), 164 | # nn.Flatten(start_dim=2), 165 | # nn.ConstantPad1d(padding=(self.padding, self.padding), value=0), 166 | nn.Conv1d(in_channels=self.num_classes, out_channels=self.num_classes, kernel_size=1), 167 | nn.Softmax(dim=1), 168 | ) 169 | nn.init.xavier_uniform_(self.layers[0].weight) 170 | nn.init.zeros_(self.layers[0].bias) 171 | 172 | def forward(self, x): 173 | # batch_size, num_classes, n_samples = x.shape 174 | # z = x.reshape((batch_size, num_classes, -1, self.epoch_length * self.sampling_frequency)) 175 | return 
self.layers(x) 176 | 177 | 178 | class UTimeModel(LightningModule): 179 | # def __init__( 180 | # self, filters=[16, 32, 64, 128], in_channels=5, maxpool_kernels=[10, 8, 6, 4], kernel_size=5, 181 | # dilation=2, sampling_frequency=128, num_classes=5, epoch_length=30, lr=1e-4, batch_size=12, 182 | # n_workers=0, eval_ratio=0.1, data_dir=None, n_jobs=-1, n_records=-1, scaling=None, **kwargs 183 | # ): 184 | def __init__( 185 | self, 186 | filters=None, 187 | in_channels=None, 188 | maxpool_kernels=None, 189 | kernel_size=None, 190 | dilation=None, 191 | num_classes=None, 192 | sampling_frequency=None, 193 | epoch_length=None, 194 | data_dir=None, 195 | n_jobs=None, 196 | n_records=None, 197 | scaling=None, 198 | lr=None, 199 | n_segments=10, 200 | *args, 201 | **kwargs 202 | ): 203 | # def __init__(self, *args, **kwargs): 204 | super().__init__() 205 | self.save_hyperparameters() 206 | # self.save_hyperparameters(hparams) 207 | # self.save_hyperparameters({k: v for k, v in hparams.items() if not callable(v)}) 208 | self.encoder = Encoder( 209 | filters=self.hparams.filters, 210 | in_channels=self.hparams.in_channels, 211 | maxpool_kernels=self.hparams.maxpool_kernels, 212 | kernel_size=self.hparams.kernel_size, 213 | dilation=self.hparams.dilation, 214 | ) 215 | self.decoder = Decoder( 216 | filters=self.hparams.filters[::-1], 217 | upsample_kernels=self.hparams.maxpool_kernels[::-1], 218 | in_channels=self.hparams.filters[-1] * 2, 219 | kernel_size=self.hparams.kernel_size, 220 | ) 221 | self.dense = nn.Sequential( 222 | nn.Conv1d(in_channels=self.hparams.filters[0], out_channels=self.hparams.num_classes, kernel_size=1, bias=True), 223 | nn.Tanh() 224 | ) 225 | nn.init.xavier_uniform_(self.dense[0].weight) 226 | nn.init.zeros_(self.dense[0].bias) 227 | self.segment_classifier = SegmentClassifier( 228 | sampling_frequency=self.hparams.sampling_frequency, 229 | num_classes=self.hparams.num_classes, 230 | epoch_length=self.hparams.epoch_length 231 | ) 232 | self.loss = 
utils.DiceLoss(self.hparams.num_classes) 233 | # self.loss = nn.CrossEntropyLoss() 234 | # self.train_acc = Accuracy() 235 | # self.eval_acc = Accuracy() 236 | # self.metric = Accuracy(num_classes=self.hparams.num_classes, reduce_op='mean') 237 | 238 | # Create Dataset params 239 | self.dataset_params = dict( 240 | data_dir=self.hparams.data_dir, 241 | n_jobs=self.hparams.n_jobs, 242 | n_records=self.hparams.n_records, 243 | scaling=self.hparams.scaling, 244 | ) 245 | 246 | # # Create DataLoader params 247 | # self.dataloader_params = dict( 248 | # batch_size=self.hparams.batch_size, 249 | # num_workers=self.hparams.n_workers, 250 | # pin_memory=True, 251 | # ) 252 | 253 | # Create Optimizer params 254 | self.optimizer_params = dict(lr=self.hparams.lr) 255 | # self.example_input_array = torch.zeros(1, self.hparams.in_channels, 35 * 30 * 100) 256 | 257 | def forward(self, x): 258 | 259 | # Run through encoder 260 | z, shortcuts = self.encoder(x) 261 | 262 | # Run through decoder 263 | z = self.decoder(z, shortcuts) 264 | 265 | # Run dense modeling 266 | z = self.dense(z) 267 | 268 | return z 269 | 270 | def classify_segments(self, x, resolution=30): 271 | 272 | # Run through encoder + decoder 273 | z = self(x) 274 | 275 | # Classify decoded samples 276 | resolution_samples = self.hparams.sampling_frequency * resolution 277 | z = z.unfold(-1, resolution_samples, resolution_samples) \ 278 | .mean(dim=-1) 279 | y = self.segment_classifier(z) 280 | 281 | return y 282 | 283 | def training_step(self, batch, batch_idx): 284 | # if batch_idx == 100: 285 | # print('hej') 286 | x, t, r, seq, stable_sleep = batch 287 | 288 | # Classify segments 289 | y = self.classify_segments(x) 290 | 291 | # loss = self.loss(y, t[:, :, ::self.hparams.epoch_length]) 292 | loss = self.compute_loss(y, t, stable_sleep) 293 | # self.train_acc(y.argmax(dim=1)[ss], t_.argmax(dim=1)[ss]) 294 | # accuracy = self.metric(y.argmax(dim=1), t[:, :, ::self.hparams.epoch_length].argmax(dim=1)) 295 | 
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) 296 | # self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True) 297 | return { 298 | 'loss': loss, 299 | 'predicted': y, 300 | 'true': t, 301 | 'record': r, 302 | 'sequence_nr': seq, 303 | 'stable_sleep': stable_sleep 304 | } 305 | # result = TrainResult(minimize=loss) 306 | # result.log('train_loss', loss, prog_bar=True, sync_dist=True) 307 | # result.log('train_acc', accuracy, prog_bar=True, sync_dist=True) 308 | # return result 309 | 310 | def validation_step(self, batch, batch_idx): 311 | x, t, r, seq, stable_sleep = batch 312 | 313 | # Classify segments 314 | y = self.classify_segments(x) 315 | 316 | # loss = self.loss(y, t[:, :, ::self.hparams.epoch_length]) 317 | loss = self.compute_loss(y, t, stable_sleep) 318 | # self.eval_acc(y.argmax(dim=1)[ss], t_.argmax(dim=1)[ss]) 319 | # accuracy = self.metric(y.argmax(dim=1), t[:, :, ::self.hparams.epoch_length].argmax(dim=1)) 320 | self.log('eval_loss', loss, on_epoch=True, on_step=False, prog_bar=True, logger=True, sync_dist=True) 321 | # self.log('eval_acc', self.eval_acc, on_epoch=True, on_step=False, prog_bar=True, logger=True) 322 | return { 323 | 'loss': loss, 324 | 'predicted': y, 325 | 'true': t, 326 | 'record': r, 327 | 'sequence_nr': seq, 328 | 'stable_sleep': stable_sleep, 329 | } 330 | # # Generate an image 331 | # if batch_idx == 0: 332 | # fig = plot_segment(x, t, z) 333 | # self.logger.experiment[1].log({'Hypnodensity': wandb.Image(fig)}) 334 | # plt.close(fig) 335 | 336 | # result = EvalResult(checkpoint_on=loss) 337 | # result.log('eval_loss', loss, prog_bar=True, sync_dist=True) 338 | # result.log('eval_acc', accuracy, prog_bar=True, sync_dist=True) 339 | # 340 | # return result 341 | 342 | def test_step(self, batch, batch_index): 343 | 344 | X, t, current_record, current_sequence, stable_sleep = batch 345 | y = self.classify_segments(X) 346 | y_1s = 
self.classify_segments(X, resolution=1) 347 | # result = ptl.EvalResult() 348 | # result.predicted = y_hat.softmax(dim=1) 349 | # result.true = y 350 | # result.record = current_record 351 | # result.sequence_nr = current_sequence 352 | # result.stable_sleep = stable_sleep 353 | # return result 354 | return { 355 | "predicted": y, 356 | "true": t, 357 | "record": current_record, 358 | "sequence_nr": current_sequence, 359 | "stable_sleep": stable_sleep, 360 | 'logits': y_1s 361 | } 362 | 363 | def training_epoch_end(self, outputs): 364 | 365 | true = torch.cat([out['true'] for out in outputs], dim=0).permute([0, 2, 1]) 366 | predicted = torch.cat([out['predicted'] for out in outputs], dim=0).permute([0, 2, 1]) 367 | stable_sleep = torch.cat([out['stable_sleep'].to(torch.int64) for out in outputs], dim=0) 368 | sequence_nrs = torch.cat([out['sequence_nr'] for out in outputs], dim=0) 369 | 370 | if self.use_ddp: 371 | out_true = [torch.zeros_like(true) for _ in range(torch.distributed.get_world_size())] 372 | out_predicted = [torch.zeros_like(predicted) for _ in range(torch.distributed.get_world_size())] 373 | out_stable_sleep = [torch.zeros_like(stable_sleep) for _ in range(torch.distributed.get_world_size())] 374 | out_seq_nrs = [torch.zeros_like(sequence_nrs) for _ in range(dist.get_world_size())] 375 | dist.barrier() 376 | dist.all_gather(out_true, true) 377 | dist.all_gather(out_predicted, predicted) 378 | dist.all_gather(out_stable_sleep, stable_sleep) 379 | dist.all_gather(out_seq_nrs, sequence_nrs) 380 | if dist.get_rank() == 0: 381 | t = torch.stack(out_true).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 382 | p = torch.stack(out_predicted).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 383 | s = torch.stack(out_stable_sleep).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec)).to(torch.bool).cpu().numpy() 384 | u = 
t.sum(axis=-1) == 1 385 | 386 | acc = metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 387 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 388 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 389 | 390 | self.log_dict({ 391 | 'train_acc': acc, 392 | 'train_cohen': cohen, 393 | 'train_f1_macro': f1_macro, 394 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 395 | elif self.on_gpu: 396 | t = true.cpu().numpy() 397 | p = predicted.cpu().detach().numpy() 398 | s = stable_sleep.to(torch.bool).cpu().numpy() 399 | u = t.sum(axis=-1) == 1 400 | 401 | acc = metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 402 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 403 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 404 | 405 | self.log_dict({ 406 | 'train_acc': acc, 407 | 'train_cohen': cohen, 408 | 'train_f1_macro': f1_macro, 409 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 410 | else: 411 | t = true.numpy() 412 | p = predicted.numpy() 413 | s = stable_sleep.to(torch.bool).numpy() 414 | u = t.sum(axis=-1) == 1 415 | 416 | acc = metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 417 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 418 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 419 | 420 | self.log_dict({ 421 | 'train_acc': acc, 422 | 'train_cohen': cohen, 423 | 'train_f1_macro': f1_macro, 424 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 425 | 426 | def validation_epoch_end(self, outputs): 427 | 428 | true = torch.cat([out['true'] for out in outputs], dim=0).permute([0, 2, 1]) 429 | predicted = torch.cat([out['predicted'] for out in 
outputs], dim=0).permute([0, 2, 1]) 430 | stable_sleep = torch.cat([out['stable_sleep'].to(torch.int64) for out in outputs], dim=0) 431 | sequence_nrs = torch.cat([out['sequence_nr'] for out in outputs], dim=0) 432 | 433 | if self.use_ddp: 434 | out_true = [torch.zeros_like(true) for _ in range(torch.distributed.get_world_size())] 435 | out_predicted = [torch.zeros_like(predicted) for _ in range(torch.distributed.get_world_size())] 436 | out_stable_sleep = [torch.zeros_like(stable_sleep) for _ in range(torch.distributed.get_world_size())] 437 | out_seq_nrs = [torch.zeros_like(sequence_nrs) for _ in range(dist.get_world_size())] 438 | dist.barrier() 439 | dist.all_gather(out_true, true) 440 | dist.all_gather(out_predicted, predicted) 441 | dist.all_gather(out_stable_sleep, stable_sleep) 442 | dist.all_gather(out_seq_nrs, sequence_nrs) 443 | if dist.get_rank() == 0: 444 | t = torch.stack(out_true).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 445 | p = torch.stack(out_predicted).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 446 | s = torch.stack(out_stable_sleep).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec)).to(torch.bool).cpu().numpy() 447 | u = t.sum(axis=-1) == 1 448 | 449 | acc = metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 450 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 451 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 452 | 453 | self.log_dict({ 454 | 'eval_acc': acc, 455 | 'eval_cohen': cohen, 456 | 'eval_f1_macro': f1_macro, 457 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 458 | elif self.on_gpu: 459 | t = true.cpu().numpy() 460 | p = predicted.cpu().numpy() 461 | s = stable_sleep.to(torch.bool).cpu().numpy() 462 | u = t.sum(axis=-1) == 1 463 | 464 | acc 
= metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 465 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 466 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 467 | 468 | self.log_dict({ 469 | 'eval_acc': acc, 470 | 'eval_cohen': cohen, 471 | 'eval_f1_macro': f1_macro, 472 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 473 | else: 474 | t = true.numpy() 475 | p = predicted.numpy() 476 | s = stable_sleep.to(torch.bool).numpy() 477 | u = t.sum(axis=-1) == 1 478 | 479 | acc = metrics.accuracy_score(t[s & u].argmax(-1), p[s & u].argmax(-1)) 480 | cohen = metrics.cohen_kappa_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4]) 481 | f1_macro = metrics.f1_score(t[s & u].argmax(-1), p[s & u].argmax(-1), labels=[0, 1, 2, 3, 4], average='macro') 482 | 483 | self.log_dict({ 484 | 'eval_acc': acc, 485 | 'eval_cohen': cohen, 486 | 'eval_f1_macro': f1_macro, 487 | }, on_step=False, prog_bar=True, logger=True, on_epoch=True) 488 | 489 | def test_epoch_end(self, outputs): 490 | """This method collects the results and sorts the predictions according to record and sequence nr.""" 491 | 492 | try: 493 | all_records = sorted(self.trainer.datamodule.test.records) 494 | except AttributeError: # Catch exception if we've supplied dataloaders instead of DataModule 495 | all_records = sorted(self.trainer.test_dataloaders[0].dataset.records) 496 | 497 | true = torch.cat([out['true'] for out in outputs], dim=0).permute([0, 2, 1]) 498 | predicted = torch.cat([out['predicted'] for out in outputs], dim=0).permute([0, 2, 1]) 499 | stable_sleep = torch.cat([out['stable_sleep'].to(torch.int64) for out in outputs], dim=0) 500 | sequence_nrs = torch.cat([out['sequence_nr'] for out in outputs], dim=0) 501 | logits = torch.cat([out['logits'] for out in outputs], dim=0).permute([0, 2, 1]) 502 | 503 | if self.use_ddp: 504 | record2int = {r: 
idx for idx, r in enumerate(all_records)} 505 | int2record = {idx: r for idx, r in enumerate(all_records)} 506 | records = torch.cat([torch.Tensor([record2int[r]]).type_as(stable_sleep) for out in outputs for r in out['record']], dim=0) 507 | out_true = [torch.zeros_like(true) for _ in range(torch.distributed.get_world_size())] 508 | out_predicted = [torch.zeros_like(predicted) for _ in range(torch.distributed.get_world_size())] 509 | out_stable_sleep = [torch.zeros_like(stable_sleep) for _ in range(torch.distributed.get_world_size())] 510 | out_seq_nrs = [torch.zeros_like(sequence_nrs) for _ in range(dist.get_world_size())] 511 | out_records = [torch.zeros_like(records) for _ in range(dist.get_world_size())] 512 | out_logits = [torch.zeros_like(logits) for _ in range(dist.get_world_size())] 513 | 514 | dist.barrier() 515 | dist.all_gather(out_true, true) 516 | dist.all_gather(out_predicted, predicted) 517 | dist.all_gather(out_stable_sleep, stable_sleep) 518 | dist.all_gather(out_seq_nrs, sequence_nrs) 519 | dist.all_gather(out_records, records) 520 | dist.all_gather(out_logits, logits) 521 | 522 | if dist.get_rank() == 0: 523 | true = torch.cat(out_true) 524 | predicted = torch.cat(out_predicted) 525 | stable_sleep = torch.cat(out_stable_sleep) 526 | sequence_nrs = torch.cat(out_seq_nrs) 527 | records = [int2record[r.item()] for r in torch.cat(out_records)] 528 | logits = torch.cat(out_logits) 529 | # t = torch.stack(out_true).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 530 | # p = torch.stack(out_predicted).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec), self.hparams.n_classes).cpu().numpy() 531 | # s = torch.stack(out_stable_sleep).transpose(0, 1).reshape(-1, int(300/self.hparams.eval_frequency_sec)).to(torch.bool).cpu().numpy() 532 | # u = t.sum(axis=-1) == 1 533 | 534 | else: 535 | return None 536 | else: 537 | records = [r for out in outputs for r in out['record']] 538 | 
539 | # elif self.on_gpu: 540 | # t = true.cpu().numpy()[:, ::self.hparams.epoch_length, :] 541 | # p = predicted.cpu().numpy() 542 | # s = stable_sleep.to(torch.bool).cpu().numpy()[:, ::self.hparams.epoch_length] 543 | # u = t.sum(axis=-1) == 1 544 | # else: 545 | # t = true.numpy()[:, ::self.hparams.epoch_length, :] 546 | # p = predicted.numpy() 547 | # s = stable_sleep.to(torch.bool).numpy()[:, ::self.hparams.epoch_length, :] 548 | # u = t.sum(axis=-1) == 1 549 | 550 | results = { 551 | r: { 552 | "true": [], 553 | "true_label": [], 554 | "predicted": [], 555 | "predicted_label": [], 556 | "stable_sleep": [], 557 | "logits": [], 558 | "seq_nr": [], 559 | # 'acc': None, 560 | # 'f1': None, 561 | # 'recall': None, 562 | # 'precision': None, 563 | } for r in all_records 564 | } 565 | 566 | for r in tqdm(all_records, desc='Sorting predictions...'): 567 | record_idx = [idx for idx, rec in enumerate(records) if r == rec] 568 | current_t = torch.cat([t for idx, t in enumerate(true) if idx in record_idx], dim=0).cpu().numpy() 569 | current_p = torch.cat([p for idx, p in enumerate(predicted) if idx in record_idx], dim=0).cpu().numpy() 570 | current_ss = torch.cat([ss.to(torch.bool) for idx, ss in enumerate(stable_sleep) if idx in record_idx], dim=0).cpu().numpy() 571 | current_l = torch.cat([l for idx, l in enumerate(logits) if idx in record_idx], dim=0).cpu().numpy() 572 | current_seq = torch.stack([seq for idx, seq in enumerate(sequence_nrs) if idx in record_idx]).cpu().numpy() 573 | 574 | results[r]['true'] = current_t.reshape(-1, self.hparams.n_segments, self.hparams.num_classes)[current_seq.argsort()].reshape(-1, self.hparams.num_classes) 575 | results[r]['predicted'] = current_p.reshape(-1, self.hparams.n_segments, self.hparams.num_classes)[current_seq.argsort()].reshape(-1, self.hparams.num_classes) 576 | results[r]['stable_sleep'] = current_ss.reshape(-1, self.hparams.n_segments)[current_seq.argsort()].reshape(-1) 577 | results[r]['logits'] = 
current_l.reshape(-1, self.hparams.epoch_length * self.hparams.n_segments, self.hparams.num_classes)[current_seq.argsort()].reshape(-1, self.hparams.num_classes)  # keep the 1-s resolution class scores themselves (in sequence order), not their shape 578 | results[r]['sequence'] = current_seq[current_seq.argsort()] 579 | 580 | return results 581 | 582 | def compute_loss(self, y_pred, y_true, stable_sleep): 583 | # stable_sleep = stable_sleep[:, ::self.hparams.epoch_length] 584 | # y_true = y_true[:, :, ::self.hparams.epoch_length] 585 | 586 | if y_pred.shape[-1] != self.hparams.num_classes: 587 | y_pred = y_pred.permute(dims=[0, 2, 1]) 588 | if y_true.shape[-1] != self.hparams.num_classes: 589 | y_true = y_true.permute(dims=[0, 2, 1]) 590 | # return self.loss(y_pred, y_true.argmax(dim=-1)) 591 | 592 | # return 593 | return self.loss(y_pred, y_true, stable_sleep) 594 | 595 | def configure_optimizers(self): 596 | return torch.optim.Adam( 597 | # [ 598 | # {'params': list(self.encoder.parameters())}, 599 | # {'params': list(self.decoder.parameters())}, 600 | # # {'params': [p[1] for p in self.named_parameters() if 'bias' not in p[0] and 'batch_norm' not in p[0]]}, 601 | # {'params': list(self.segment_classifier.parameters())[0], 'weight_decay': 1e-5}, 602 | # {'params': list(self.segment_classifier.parameters())[1]}, 603 | # ], 604 | self.parameters(), **self.optimizer_params 605 | ) 606 | 607 | # def on_after_backward(self): 608 | # print('Hej') 609 | 610 | # def train_dataloader(self): 611 | # """Return training dataloader.""" 612 | # return DataLoader(self.train_data, shuffle=True, **self.dataloader_params) 613 | 614 | # def val_dataloader(self): 615 | # """Return validation dataloader.""" 616 | # return DataLoader(self.eval_data, shuffle=False, **self.dataloader_params) 617 | 618 | # def setup(self, stage): 619 | # if stage == 'fit': 620 | # self.dataset = SscWscPsgDataset(**self.dataset_params) 621 | # self.train_data, self.eval_data = self.dataset.split_data(self.hparams.eval_ratio) 622 | 623 | @staticmethod 624 | def
add_model_specific_args(parent_parser): 625 | 626 | # MODEL specific 627 | parser = ArgumentParser(parents=[parent_parser], add_help=True) 628 | architecture_group = parser.add_argument_group('architecture') 629 | architecture_group.add_argument('--filters', default=[16, 32, 64, 128], nargs='+', type=int) 630 | architecture_group.add_argument('--in_channels', default=5, type=int) 631 | architecture_group.add_argument('--maxpool_kernels', default=[10, 8, 6, 4], nargs='+', type=int) 632 | architecture_group.add_argument('--kernel_size', default=5, type=int) 633 | architecture_group.add_argument('--dilation', default=2, type=int) 634 | architecture_group.add_argument('--sampling_frequency', default=128, type=int) 635 | architecture_group.add_argument('--num_classes', default=5, type=int) 636 | architecture_group.add_argument('--epoch_length', default=30, type=int) 637 | architecture_group.add_argument('--n_segments', default=10, type=int) 638 | 639 | # OPTIMIZER specific 640 | optimizer_group = parser.add_argument_group('optimizer') 641 | # optimizer_group.add_argument('--optimizer', default='sgd', type=str) 642 | optimizer_group.add_argument('--lr', default=5e-6, type=float) 643 | # optimizer_group.add_argument('--momentum', default=0.9, type=float) 644 | # optimizer_group.add_argument('--weight_decay', default=0, type=float) 645 | 646 | # LEARNING RATE SCHEDULER specific 647 | # lr_scheduler_group = parser.add_argument_group('lr_scheduler') 648 | # lr_scheduler_group.add_argument('--lr_scheduler', default=None, type=str) 649 | # lr_scheduler_group.add_argument('--base_lr', default=0.05, type=float) 650 | # lr_scheduler_group.add_argument('--lr_reduce_factor', default=0.1, type=float) 651 | # lr_scheduler_group.add_argument('--lr_reduce_patience', default=5, type=int) 652 | # lr_scheduler_group.add_argument('--max_lr', default=0.15, type=float) 653 | # lr_scheduler_group.add_argument('--step_size_up', default=0.05, type=int) 654 | 655 | # DATASET specific 656 | # 
dataset_group = parser.add_argument_group('dataset') 657 | # dataset_group.add_argument('--data_dir', default='data/train/raw/individual_encodings', type=str) 658 | # dataset_group.add_argument('--eval_ratio', default=0.1, type=float) 659 | # dataset_group.add_argument('--n_jobs', default=-1, type=int) 660 | # dataset_group.add_argument('--n_records', default=-1, type=int) 661 | # dataset_group.add_argument('--scaling', default=None, type=str) 662 | 663 | 664 | # DATALOADER specific 665 | # dataloader_group = parser.add_argument_group('dataloader') 666 | # dataloader_group.add_argument('--batch_size', default=12, type=int) 667 | # dataloader_group.add_argument('--n_workers', default=0, type=int) 668 | 669 | return parser 670 | 671 | 672 | if __name__ == "__main__": 673 | from datasets import SscWscDataModule 674 | from pytorch_lightning.core.memory import ModelSummary 675 | 676 | parser = ArgumentParser(add_help=False) 677 | parser = SscWscDataModule.add_dataset_specific_args(parser) 678 | parser = UTimeModel.add_model_specific_args(parser) 679 | args = parser.parse_args() 680 | # parser.add_argument('--filters', default=[16, 32, 64, 128], nargs='+', type=int) 681 | # args = parser.parse_args() 682 | # print('Filters:', args.filters) 683 | args.in_channels = 4 684 | in_channels = args.in_channels 685 | x_shape = (1, in_channels, 10 * 30 * 128) 686 | x = torch.rand(x_shape) 687 | 688 | # # Test ConvBNReLU block 689 | # z = ConvBNReLU()(x) 690 | # print() 691 | # print(ConvBNReLU()) 692 | # print(x.shape) 693 | # print(z.shape) 694 | 695 | # # test Encoder class 696 | # encoder = Encoder() 697 | # print(encoder) 698 | # print("x.shape:", x.shape) 699 | # z, shortcuts = encoder(x) 700 | # print("z.shape:", z.shape) 701 | # print("Shortcuts shape:", [shortcut.shape for shortcut in shortcuts]) 702 | 703 | # # Test Decoder class 704 | # z_shape = (32, 256, 54) 705 | # z = torch.rand(z_shape) 706 | # decoder = Decoder() 707 | # print(decoder) 708 | # x_hat = decoder(z, 
None) 709 | # print("x_hat.shape:", x_hat.shape) 710 | 711 | # Test UTimeModel Class 712 | # utime = UTimeModel(in_channels=in_channels) 713 | utime = UTimeModel(vars(args)) 714 | utime.example_input_array = torch.zeros(x_shape) 715 | utime.configure_optimizers() 716 | model_summary = ModelSummary(utime, "full") 717 | print(model_summary) 718 | print(utime) 719 | print(x.shape) 720 | # z = utime(x) 721 | z = utime.classify_segments(x) 722 | print(z.shape) 723 | print("x.shape:", x.shape) 724 | print("z.shape:", z.shape) 725 | print(z.sum(dim=1)) 726 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | from pytorch_lightning import seed_everything 7 | from pytorch_lightning import Trainer 8 | 9 | import utils 10 | 11 | 12 | torch.backends.cudnn.benchmark = True 13 | 14 | 15 | def run_training(): 16 | 17 | args = utils.get_args() 18 | 19 | # Remember to seed! 
20 | seed_everything(args.seed) 21 | 22 | # Setup data module for training 23 | dm, args = utils.get_data(args) 24 | 25 | # Setup model 26 | model = utils.get_model(args) 27 | 28 | # Setup callbacks 29 | loggers, callbacks = utils.get_loggers_callbacks(args, model) 30 | 31 | # Define trainer object from arguments 32 | trainer = Trainer.from_argparse_args(args, deterministic=True, logger=loggers, callbacks=callbacks) 33 | 34 | # ================================================================================================================ 35 | # LEARNING RATE FINDER ROUTINE 36 | # ---------------------------------------------------------------------------------------------------------------- 37 | if args.lr_finder: 38 | lr_finder = trainer.tuner.lr_find(model, datamodule=dm, min_lr=1e-7, max_lr=5e-5) 39 | fig = lr_finder.plot(suggest=True) 40 | fig.savefig("results/lr_finder/lr_range_test_bs32.png") 41 | return 42 | # ================================================================================================================ 43 | 44 | # Fit model using trainer 45 | trainer.fit(model, dm) 46 | 47 | # return 0 48 | 49 | # Return results on eval data 50 | predictions = trainer.test( 51 | model, 52 | # test_dataloaders=dm.val_dataloader(), 53 | test_dataloaders=dm.train_dataloader(), 54 | # ckpt_path=trainer.checkpoint_callback.best_model_path, 55 | verbose=False, 56 | ) 57 | if not model.use_ddp or (model.use_ddp and torch.distributed.get_rank() == 0): 58 | predictions = predictions[0] 59 | 60 | try: 61 | save_dir = args.save_dir 62 | except AttributeError: 63 | save_dir = os.path.dirname(args.resume_from_checkpoint) 64 | 65 | with open(os.path.join(save_dir, f"train_predictions.pkl"), "wb") as pkl: 66 | pickle.dump(predictions, pkl) 67 | 68 | eval_windows = [1] 69 | df, cm_sub, cm_tot = utils.evaluate_performance(predictions, evaluation_windows=eval_windows, cases=["all"]) 70 | best_acc = [] 71 | best_kappa = [] 72 | with np.printoptions(precision=3, 
suppress=True): 73 | s = "" 74 | for eval_window in cm_tot.keys(): 75 | # print() 76 | s += "\n" 77 | s += f"Evaluation window - {eval_window} s\n" 78 | s += "---------------------------------\n" 79 | for case in cm_tot[eval_window].keys(): 80 | df_ = df.query(f'Window == "{eval_window} s" and Case == "{case}"') 81 | s += f"Case: {case}\n" 82 | s += f"{cm_tot[eval_window][case]}\n" 83 | NP = cm_tot[eval_window][case].sum(axis=1) 84 | PP = cm_tot[eval_window][case].sum(axis=0) 85 | N = cm_tot[eval_window][case].sum() 86 | precision = np.diag(cm_tot[eval_window][case]) / (PP + 1e-10) 87 | recall = np.diag(cm_tot[eval_window][case]) / (NP + 1e-10) 88 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 89 | acc = np.diag(cm_tot[eval_window][case]).sum() / N 90 | 91 | pe = N ** (-2) * (NP @ PP) 92 | kappa = 1 - (1 - acc) / (1 - pe) 93 | 94 | c = np.diag(cm_tot[eval_window][case]).sum() 95 | mcc = (c * N - NP @ PP) / (np.sqrt(N ** 2 - (PP @ PP)) * np.sqrt(N ** 2 - (NP @ NP))) 96 | 97 | s += "\n" 98 | s += f'Precision:\t{df_["Precision"].mean():.3f} +/- {df_["Precision"].std():.3f} \t|\t{precision}\n' 99 | s += f'Recall:\t\t{df_["Recall"].mean():.3f} +/- {df_["Recall"].std():.3f} \t|\t{recall}\n' 100 | s += f'F1: \t\t{df_["F1"].mean():.3f} +/- {df_["F1"].std():.3f} \t|\t{f1}\n' 101 | s += f'Accuracy:\t{df_["Accuracy"].mean():.3f} +/- {df_["Accuracy"].std():.3f} \t|\t{acc:.3f}\n' 102 | s += f'Kappa:\t\t{df_["Kappa"].mean():.3f} +/- {df_["Kappa"].std():.3f} \t|\t{kappa:.3f}\n' 103 | s += f'MCC:\t\t{df_["MCC"].mean():.3f} +/- {df_["MCC"].std():.3f} \t|\t{mcc:.3f}\n' 104 | s += "\n" 105 | 106 | best_acc.append(acc) 107 | best_kappa.append(kappa) 108 | print(s) 109 | with open(os.path.join(save_dir, "train_case_results.txt"), "w") as txt_file: 110 | print(s, file=txt_file) 111 | df.to_csv(os.path.join(save_dir, f"train_results.csv")) 112 | with open(os.path.join(save_dir, f"train_confusionmatrix.pkl"), "wb") as pkl: 113 | pickle.dump({"confusiomatrix_subject": cm_sub, 
"confusionmatrix_total": cm_tot}, pkl) 114 | 115 | try: 116 | trainer.logger.experiment.summary["best_acc"] = best_acc 117 | trainer.logger.experiment.summary["best_kappa"] = best_kappa 118 | except AttributeError: 119 | trainer.logger.experiment[1].summary["best_acc"] = best_acc 120 | trainer.logger.experiment[1].summary["best_kappa"] = best_kappa 121 | 122 | # # Run predictions on test data 123 | # if args.model_type == "stages": 124 | # results_dir = os.path.join( 125 | # "results", 126 | # args.model_type, 127 | # args.model_name, 128 | # args.resume_from_checkpoint.split("/")[2], 129 | # os.path.basename(wandb_logger.save_dir), 130 | # ) 131 | # elif args.model_type == "massc": 132 | # results_dir = Path(os.path.join(args.save_dir, "results")) 133 | # results_dir.mkdir(parents=True, exist_ok=True) 134 | # # results_dir = os.path.join( 135 | # # "results", args.model_type, args.resume_from_checkpoint.split("/")[2], os.path.basename(wandb_logger.save_dir), 136 | # # ) 137 | # # if not os.path.exists(results_dir): 138 | # # os.makedirs(results_dir) 139 | # # test_data = DataLoader(SscWscPsgDataset("./data/test/raw/ssc_wsc"), num_workers=args.n_workers, pin_memory=True) 140 | # # results = trainer.test(test_dataloaders=test_data, verbose=False)[0] 141 | # # evaluate_performance(results) 142 | # # print(len(results.keys())) 143 | # # with open(os.path.join(results_dir, 'SSC_WSC.pkl'), 'wb') as pkl: 144 | # # pickle.dump(results, pkl) 145 | 146 | # # KHC data 147 | # khc_data = DataLoader(KoreanDataset(), num_workers=args.n_workers, pin_memory=True) 148 | # results = trainer.test(test_dataloaders=khc_data, verbose=False) 149 | # df = evaluate_performance(results) 150 | # print(len(results.keys())) 151 | # with open(os.path.join(results_dir, "KHC.pkl"), "wb") as pkl: 152 | # pickle.dump(results, pkl) 153 | 154 | # df.to_csv(os.path.join(results_dir, 'KHC.csv')) 155 | 156 | # results = trainer.test(verbose=False) 157 | # test_params = dict(num_workers=args.n_workers, 
pin_memory=True) 158 | 159 | # test_data = DataLoader(datasets.SscWscPsgDataset(data_dir=args.data_dir, overlap=False, n_records=20, scaling="robust")) 160 | # test_data = DataLoader(SscWscPsgDataset("./data/test/raw/ssc_wsc", overlap=False, n_records=10), **test_params) 161 | # run_testing(test_data, "SSC-WSC") 162 | 163 | return 0 164 | 165 | 166 | if __name__ == "__main__": 167 | run_training() 168 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .arg_utils import get_args 2 | from .dataset_utils import get_data 3 | from .evaluate_performance import evaluate_performance 4 | from .h5_utils import get_h5_info 5 | from .h5_utils import load_h5_data 6 | from .h5_utils import load_h5_data 7 | from .logger_callback_utils import get_loggers_callbacks 8 | from .losses import DiceLoss 9 | from .model_utils import get_model 10 | from .parallel_bar import ParallelExecutor 11 | -------------------------------------------------------------------------------- /utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | from datetime import datetime 5 | from glob import glob 6 | 7 | import torch 8 | import pytorch_lightning as pl 9 | 10 | import datasets 11 | import models 12 | 13 | 14 | def get_args(): 15 | 16 | parser = argparse.ArgumentParser(add_help=False) 17 | parser.add_argument("--name", default=None, type=str) 18 | parser.add_argument("--debug", action="store_true") 19 | parser.add_argument("--lr_finder", action="store_true") 20 | parser.add_argument("--seed", default=1337, type=int) 21 | parser.add_argument("--dataset_type", type=str, default="ssc-wsc") 22 | parser.add_argument("--checkpoint_monitor", default="eval_loss", type=str) 23 | parser.add_argument("--earlystopping_monitor", default="eval_loss", type=str) 24 | 
parser.add_argument("--earlystopping_patience", default=100, type=int) 25 | 26 | # add args from trainer 27 | parser = pl.Trainer.add_argparse_args(parser) 28 | 29 | # Check the supplied model type 30 | parser.add_argument("--model_type", type=str, default="utime") 31 | temp_args, _ = parser.parse_known_args() 32 | 33 | # Optionally resume from checkpoint 34 | if temp_args.resume_from_checkpoint and os.path.isdir(temp_args.resume_from_checkpoint): 35 | temp_args.resume_from_checkpoint = glob(os.path.join(temp_args.resume_from_checkpoint, "epoch*.ckpt"))[0] 36 | if temp_args.resume_from_checkpoint: 37 | hparams = torch.load(temp_args.resume_from_checkpoint, map_location=torch.device("cpu"))["hyper_parameters"] 38 | temp_args.model_type = hparams["model_type"] 39 | 40 | # add args from dataset 41 | parser = datasets.available_datasets["ssc-wsc"].add_dataset_specific_args(parser) 42 | 43 | # add args from model 44 | parser = models.available_models[temp_args.model_type].add_model_specific_args(parser) 45 | 46 | # parse params 47 | args = parser.parse_args() 48 | 49 | # update args from hparams 50 | if args.resume_from_checkpoint: 51 | args.model_type = hparams["model_type"] 52 | 53 | # Create a save directory 54 | if not args.resume_from_checkpoint: 55 | try: 56 | args.save_dir = os.path.join( 57 | "experiments", args.model_type, args.model_name, args.name, datetime.now().strftime("%Y%m%d_%H%M%S"), 58 | ) 59 | except AttributeError: 60 | args.save_dir = os.path.join("experiments", "utime", datetime.now().strftime("%Y%m%d_%H%M%S")) 61 | 62 | # Get the best model from the directory by default 63 | if args.resume_from_checkpoint and os.path.isdir(args.resume_from_checkpoint): 64 | args.resume_from_checkpoint = glob(os.path.join(args.resume_from_checkpoint, "epoch*.ckpt"))[0] 65 | 66 | # If you wish to view applied settings, uncomment these two lines. 
67 | if args.debug: 68 | pprint.pprint(vars(args)) 69 | 70 | return args 71 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | 3 | 4 | def get_data(args): 5 | 6 | dm = datasets.available_datasets[args.dataset_type](**vars(args)) 7 | dm.setup() 8 | 9 | try: 10 | args.cb_weights = dm.train.dataset.cb_weights 11 | except AttributeError: 12 | pass 13 | 14 | return dm, args 15 | -------------------------------------------------------------------------------- /utils/evaluate_performance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.metrics import ( 4 | accuracy_score, 5 | balanced_accuracy_score, 6 | cohen_kappa_score, 7 | confusion_matrix, 8 | f1_score, 9 | matthews_corrcoef, 10 | precision_score, 11 | recall_score, 12 | precision_recall_fscore_support, 13 | ) 14 | from tqdm import tqdm 15 | 16 | 17 | def transition_matrix(y, n_classes=5, eval_frequency=[30]): 18 | 19 | # Generate count matrix 20 | M = np.zeros((n_classes, n_classes)) 21 | for i, j in zip(y, y[1:]): 22 | M[i, j] += 1 23 | 24 | # Convert to probabilites 25 | M /= M.sum(axis=1, keepdims=True) 26 | 27 | return M 28 | 29 | 30 | # def get_transitions_matrices(record_predictions, eval_frequency=[30]): 31 | 32 | 33 | # fmt: off 34 | def evaluate_performance(record_predictions, evaluation_windows=[1, 3, 5, 10, 15, 30], cases=['all', 'stable', 'transition']): 35 | """Evaluate the performance of the predicted results. 
36 | 37 | Args: 38 | record_predictions (dict): dict containing predicted and true labels for every record 39 | """ 40 | records = [r for r in record_predictions.keys()] 41 | ids = [r.split(".")[0] for r in records] 42 | df_total = [] 43 | confmat_subject = {record: {eval_window: {case: None for case in cases} for eval_window in evaluation_windows} for record in records} 44 | confmat_total = {eval_window: {case: np.zeros((5, 5)) for case in cases} for eval_window in evaluation_windows} 45 | for eval_window in evaluation_windows: 46 | for case in cases: 47 | df = pd.DataFrame() 48 | df["FileID"] = records 49 | df["SubjectID"] = ids 50 | df["Window"] = f"{eval_window} s" 51 | df["Case"] = case 52 | print('') 53 | print(f'Evaluation window: {eval_window} | Case: {case}') 54 | for idx, record in enumerate(tqdm(records)): 55 | not_unknown_stage = record_predictions[record]['true'].sum(axis=1) == 1 56 | 57 | if case == 'all': 58 | extract = np.full(record_predictions[record]['true'].shape[0], True) & not_unknown_stage 59 | elif case == 'stable': 60 | extract = record_predictions[record]['stable_sleep'] & not_unknown_stage 61 | elif case == 'transition': 62 | not_unknown_stage = record_predictions[record]['true'].sum(axis=1) == 1 63 | extract = np.invert(record_predictions[record]['stable_sleep']) & not_unknown_stage 64 | 65 | # Get the true and predicted stages 66 | t = record_predictions[record]["true"][extract, :].argmax(axis=1)[::eval_window] 67 | p = np.mean(record_predictions[record]["predicted"][extract, :].reshape(-1, eval_window, 5), axis=1).argmax(axis=1) 68 | 69 | # Extract the metrics 70 | acc = accuracy_score(t, p) 71 | bal_acc = balanced_accuracy_score(t, p) 72 | kappa = cohen_kappa_score(t, p) 73 | f1 = f1_score(t, p, average="macro") 74 | prec = precision_score(t, p, average="macro") 75 | recall = recall_score(t, p, average="macro") 76 | mcc = matthews_corrcoef(t, p) 77 | 78 | # Assign metrics to dataframe 79 | df.loc[idx, "Accuracy"] = acc 80 | 
df.loc[idx, "Balanced accuracy"] = bal_acc 81 | df.loc[idx, "Kappa"] = kappa 82 | df.loc[idx, "F1"] = f1 83 | df.loc[idx, "Precision"] = prec 84 | df.loc[idx, "Recall"] = recall 85 | df.loc[idx, "MCC"] = mcc 86 | 87 | # Get stage-specific metrics 88 | precision, recall, f1, support = precision_recall_fscore_support(t, p, labels=[0, 1, 2, 3, 4]) 89 | 90 | # Assign to dataframe 91 | for stage_idx, stage in zip([0, 1, 2, 3, 4], ["W", "N1", "N2", "N3", "REM"]): 92 | df.loc[idx, f"F1 - {stage}"] = f1[stage_idx] 93 | df.loc[idx, f"Precision - {stage}"] = precision[stage_idx] 94 | df.loc[idx, f"Recall - {stage}"] = recall[stage_idx] 95 | df.loc[idx, f"Support - {stage}"] = support[stage_idx] 96 | 97 | # Get confusion matrix 98 | C = confusion_matrix(t, p, labels=[0, 1, 2, 3, 4]) 99 | confmat_subject[record][eval_window][case] = C 100 | confmat_total[eval_window][case] += C 101 | 102 | # Update list 103 | df_total.append(df) 104 | 105 | # Finalize dataframe 106 | df_total = pd.concat(df_total) 107 | 108 | return df_total, confmat_subject, confmat_total 109 | # fmt: on 110 | 111 | # metrics = { 112 | # "record": records, 113 | # "id": ids, 114 | # "macro_f1": [], 115 | # "micro_f1": [], 116 | # "accuracy": [], 117 | # "balanced_accuracy": [], 118 | # "kappa": [], 119 | # "mcc": [], 120 | # "macro_recall": [], 121 | # "micro_recall": [], 122 | # "macro_precision": [], 123 | # "micro_precision": [], 124 | # } 125 | # for record in records: 126 | 127 | # y_true = record_predictions[record]["true_label"] 128 | # y_pred = record_predictions[record]["predicted_label"] 129 | # # labels = [0, 1, 2, 3, 4] 130 | 131 | # metrics["macro_f1"].append(f1_score(y_true, y_pred, average="macro")) 132 | # metrics["micro_f1"].append(f1_score(y_true, y_pred, average="micro")) 133 | # metrics["accuracy"].append(accuracy_score(y_true, y_pred)) 134 | # metrics["balanced_accuracy"].append(balanced_accuracy_score(y_true, y_pred)) 135 | # metrics["kappa"].append(cohen_kappa_score(y_true, y_pred)) 136 
| # metrics["mcc"].append(matthews_corrcoef(y_true, y_pred)) 137 | # metrics["macro_recall"].append(recall_score(y_true, y_pred, average="macro")) 138 | # metrics["micro_recall"].append(recall_score(y_true, y_pred, average="micro")) 139 | # metrics["macro_precision"].append(precision_score(y_true, y_pred, average="macro")) 140 | # metrics["micro_precision"].append(precision_score(y_true, y_pred, average="micro")) 141 | # # metrics["macro_f1"].append(f1_score(y_true, y_pred, labels=labels, average="macro")) 142 | # # metrics["micro_f1"].append(f1_score(y_true, y_pred, labels=labels, average="micro")) 143 | # # metrics["accuracy"].append(accuracy_score(y_true, y_pred)) 144 | # # metrics["balanced_accuracy"].append(balanced_accuracy_score(y_true, y_pred)) 145 | # # metrics["kappa"].append(cohen_kappa_score(y_true, y_pred, labels=labels)) 146 | # # metrics["mcc"].append(matthews_corrcoef(y_true, y_pred)) 147 | # # metrics["macro_recall"].append(recall_score(y_true, y_pred, labels=labels, average="macro")) 148 | # # metrics["micro_recall"].append(recall_score(y_true, y_pred, labels=labels, average="micro")) 149 | # # metrics["macro_precision"].append(precision_score(y_true, y_pred, labels=labels, average="macro")) 150 | # # metrics["micro_precision"].append(precision_score(y_true, y_pred, labels=labels, average="micro")) 151 | # total_acc = [] 152 | 153 | # return pd.DataFrame.from_dict(metrics).set_index("record") 154 | -------------------------------------------------------------------------------- /utils/h5_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from h5py import File 3 | from sklearn import preprocessing 4 | 5 | 6 | SCALERS = {"robust": preprocessing.RobustScaler, "standard": preprocessing.StandardScaler} 7 | 8 | 9 | def get_h5_info(filename): 10 | 11 | with File(filename, "r") as h5: 12 | dataT = h5["trainD"] 13 | seqs_in_file = dataT.shape[0] 14 | 15 | return seqs_in_file 16 | 17 | 18 | 
def load_h5_data(filename, seg_size): 19 | 20 | with File(filename, "r") as h5: 21 | # print(h5.keys()) 22 | dataT = h5["trainD"][:].astype("float32") 23 | targetT = h5["trainL"][:].astype("float32") 24 | weights = h5["trainW"][:].astype("float32") 25 | # dataT = h5["trainD"] 26 | # print("hej") 27 | 28 | # Hack to make sure axis order is preserved 29 | if dataT.shape[-1] == 300: 30 | dataT = np.swapaxes(dataT, 0, 2) 31 | targetT = np.swapaxes(targetT, 0, 2) 32 | weights = weights.T 33 | 34 | # print(dataT.shape) 35 | # print(targetT.shape) 36 | # print(weights.shape) 37 | # print(f'{filename} loaded - Training') 38 | 39 | seq_in_file = dataT.shape[0] 40 | # n_segs = dataT.shape[1] // seg_size 41 | n_segs = dataT.shape[-1] // seg_size 42 | 43 | # return ( 44 | # np.reshape(dataT, [seq_in_file, n_segs, seg_size, -1]), 45 | # np.reshape(targetT, [seq_in_file, n_segs, seg_size, -1]), 46 | # np.reshape(weights, [seq_in_file, n_segs, seg_size]), 47 | # seq_in_file, 48 | # ) 49 | return ( 50 | np.reshape(dataT, [seq_in_file, -1, n_segs, seg_size]), 51 | np.reshape(targetT, [seq_in_file, -1, n_segs, seg_size]), 52 | np.reshape(weights, [seq_in_file, n_segs, seg_size]), 53 | seq_in_file, 54 | ) 55 | 56 | 57 | def load_psg_h5_data(filename, scaling=None): 58 | scaler = None 59 | 60 | if scaling: 61 | scaler = SCALERS[scaling]() 62 | 63 | with File(filename, "r") as h5: 64 | N, C, T = h5["M"].shape 65 | sequences_in_file = N 66 | 67 | if scaling: 68 | scaler.fit(h5["M"][:].transpose(1, 0, 2).reshape((C, N * T)).T) 69 | 70 | return sequences_in_file, scaler 71 | -------------------------------------------------------------------------------- /utils/logger_callback_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pytorch_lightning import loggers as pl_loggers 4 | from pytorch_lightning import callbacks as pl_callbacks 5 | 6 | 7 | def get_loggers_callbacks(args, model=None): 8 | 9 | try: 10 | # Setup logger(s) 
params 11 | csv_logger_params = dict( 12 | save_dir="./experiments", 13 | name=os.path.join(*args.save_dir.split("/")[1:-1]), 14 | version=args.save_dir.split("/")[-1], 15 | ) 16 | wandb_logger_params = dict( 17 | log_model=False, 18 | name=os.path.join(*args.save_dir.split("/")[1:]), 19 | offline=args.debug, 20 | project="utime", 21 | save_dir=args.save_dir, 22 | ) 23 | loggers = [ 24 | pl_loggers.CSVLogger(**csv_logger_params), 25 | pl_loggers.WandbLogger(**wandb_logger_params), 26 | ] 27 | if model: 28 | loggers[-1].watch(model) 29 | 30 | # Setup callback(s) params 31 | checkpoint_monitor_params = dict( 32 | filepath=os.path.join(args.save_dir, "{epoch:03d}-{eval_loss:.2f}"), 33 | monitor=args.checkpoint_monitor, 34 | save_last=True, 35 | save_top_k=1, 36 | ) 37 | earlystopping_parameters = dict(monitor=args.earlystopping_monitor, patience=args.earlystopping_patience,) 38 | callbacks = [ 39 | pl_callbacks.ModelCheckpoint(**checkpoint_monitor_params), 40 | pl_callbacks.EarlyStopping(**earlystopping_parameters), 41 | pl_callbacks.LearningRateMonitor(), 42 | ] 43 | 44 | return loggers, callbacks 45 | except AttributeError: 46 | return None, None 47 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DiceLoss(nn.Module): 6 | def __init__(self, num_classes, smooth=1): 7 | super().__init__() 8 | self.num_classes = num_classes 9 | self.smooth = smooth 10 | 11 | def forward(self, pred, target, mask=None): 12 | 13 | assert ( 14 | pred.shape == target.shape 15 | ), f"Target shape: {target.shape} does not match predicted shape: {pred.shape}!" 
16 | 17 | if mask is None: 18 | mask = torch.full(pred.shape, True).type_as(pred) 19 | elif mask.shape[-1] != self.num_classes: 20 | mask = mask.unsqueeze(-1).repeat(1, 1, self.num_classes) 21 | 22 | pred = pred * mask.float().int() 23 | target = target * mask.float().int() 24 | 25 | reduction_dims = list(range(len(pred.shape))[1:-1]) 26 | 27 | intersection = torch.sum(pred * target, dim=reduction_dims) 28 | union = torch.sum(pred + target, dim=reduction_dims) 29 | dice = (2 * intersection + self.smooth) / (union + self.smooth) 30 | return (1 - torch.mean(dice, dim=-1)).mean() 31 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | 5 | import models 6 | 7 | 8 | def get_model(args): 9 | 10 | if args.resume_from_checkpoint: 11 | # resume_from_checkpoint = args.resume_from_checkpoint 12 | # args = Namespace( 13 | # **(torch.load(args.resume_from_checkpoint, map_location=torch.device("cpu"))["hyper_parameters"]) 14 | # ) 15 | # args.resume_from_checkpoint = resume_from_checkpoint 16 | model = models.available_models[args.model_type].load_from_checkpoint(args.resume_from_checkpoint) 17 | else: 18 | model = models.available_models[args.model_type](**vars(args)) 19 | 20 | return model 21 | -------------------------------------------------------------------------------- /utils/parallel_bar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from joblib import Parallel 4 | from tqdm import tqdm 5 | 6 | 7 | def text_progressbar(seq, total=None): 8 | step = 1 9 | tick = time.time() 10 | while True: 11 | time_diff = time.time() - tick 12 | avg_speed = time_diff / step 13 | total_str = f"of {total}" if total else "" 14 | print("step", step, "%.2f" % time_diff, "avg: %.2f iter/sec" % avg_speed, total_str) 15 | step += 1 16 | yield 
next(seq) 17 | 18 | 19 | all_bar_funcs = { 20 | "tqdm": lambda args: lambda x: tqdm(x, **args), 21 | "txt": lambda args: lambda x: text_progressbar(x, **args), 22 | "False": lambda args: iter, 23 | "None": lambda args: iter, 24 | } 25 | 26 | 27 | def ParallelExecutor(use_bar="tqdm", **joblib_args): 28 | def aprun(bar=use_bar, **tq_args): 29 | def tmp(op_iter): 30 | if str(bar) in all_bar_funcs.keys(): 31 | bar_func = all_bar_funcs[str(bar)](tq_args) 32 | else: 33 | raise ValueError("Value %s not supported as bar type" % bar) 34 | return Parallel(**joblib_args)(bar_func(op_iter)) 35 | 36 | return tmp 37 | 38 | return aprun 39 | --------------------------------------------------------------------------------