├── Kfold_trainer.py
├── README.md
├── args.py
├── data_loader.py
├── data_preprocess_TF.py
├── dataset_prepare.py
├── early_stop_tool.py
├── imgs
│   └── MultiChannelSleepNet.png
├── model.py
└── result_evaluate.py


/Kfold_trainer.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score

from model import Transformer
from early_stop_tool import EarlyStopping
from data_loader import data_generator
from args import Config, Path


def set_random_seed(seed=0):
    np.random.seed(seed)
    torch.manual_seed(seed)       # CPU
    torch.cuda.manual_seed(seed)  # GPU


def test(model, test_loader, config):
    """Evaluate the model on a loader; returns (accuracy, summed cross-entropy loss)."""
    criterion = nn.CrossEntropyLoss()
    model.eval()

    pred = []
    label = []
    test_loss = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.to(config.device)
            target = target.to(config.device)

            output = model(data)
            test_loss += criterion(output, target.long()).item()

            pred.extend(np.argmax(output.data.cpu().numpy(), axis=1))
            label.extend(target.data.cpu().numpy())

    accuracy = accuracy_score(label, pred)
    return accuracy, test_loss


def train(save_all_checkpoint=False):
    config = Config()
    path = Path()

    dataset, labels, val_loader = data_generator(path_labels=path.path_labels, path_dataset=path.path_TF)

    kf = StratifiedKFold(n_splits=config.num_fold, shuffle=True, random_state=0)

    for fold, (train_idx, test_idx) in enumerate(kf.split(dataset, labels)):
        print('\n', '-' * 15, '>', f'Fold {fold}', '<', '-' * 15)
        os.makedirs('./Kfold_models/fold{}'.format(fold), exist_ok=True)

        X_train, X_test = dataset[train_idx], dataset[test_idx]
        y_train, y_test = labels[train_idx], labels[test_idx]
        train_set = TensorDataset(X_train, y_train)
        test_set = TensorDataset(X_test, y_test)
        train_loader = DataLoader(dataset=train_set, batch_size=config.batch_size, shuffle=False)
        test_loader = DataLoader(dataset=test_set, batch_size=config.batch_size, shuffle=False)

        model = Transformer(config)
        model = model.to(config.device)

        criterion = nn.CrossEntropyLoss()

        # AdamW optimizer
        optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)

        # Apply early stopping. To keep a checkpoint for every epoch of the
        # training process, set save_all_checkpoint=True.
        early_stopping = EarlyStopping(patience=20, verbose=True, save_all_checkpoint=save_all_checkpoint)

        # evaluation metrics, tracked per epoch
        train_ACC, train_LOSS = [], []
        test_ACC, test_LOSS = [], []
        val_ACC, val_LOSS = [], []

        for epoch in range(config.num_epochs):
            running_loss = 0.0
            correct = 0

            model.train()

            loop = tqdm(enumerate(train_loader), total=len(train_loader))
            for batch_idx, (data, target) in loop:
                data = data.to(config.device)
                target = target.to(config.device)

                optimizer.zero_grad()
                output = model(data)

                loss = criterion(output, target.long())
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                batch_correct = np.sum(np.argmax(output.data.cpu().numpy(), axis=1) == target.data.cpu().numpy())
                loop.set_postfix(train_acc=batch_correct / target.shape[0], loss=loss.item())
                correct += batch_correct

            train_acc = correct / len(train_loader.dataset)
            test_acc, test_loss = test(model, test_loader, config)
            val_acc, val_loss = test(model, val_loader, config)
            print('Epoch: ', epoch,
                  '| train loss: %.4f' % running_loss, '| train acc: %.4f' % train_acc,
                  '| val acc: %.4f' % val_acc, '| val loss: %.4f' % val_loss,
                  '| test acc: %.4f' % test_acc, '| test loss: %.4f' % test_loss)

            train_ACC.append(train_acc)
            train_LOSS.append(running_loss)
            test_ACC.append(test_acc)
            test_LOSS.append(test_loss)
            val_ACC.append(val_acc)
            val_LOSS.append(val_loss)

            # Check whether to continue training.
            # If save_all_checkpoint=False, the saved model name will be 'model.pkl'.
            early_stopping(val_acc, model, path='./Kfold_models/fold{}/model_{}_epoch{}.pkl'.format(fold, fold, epoch))

            if early_stopping.early_stop:
                print("Early stopping at epoch ", epoch)
                break

        np.save('./Kfold_models/fold{}/train_LOSS.npy'.format(fold), np.array(train_LOSS))
        np.save('./Kfold_models/fold{}/train_ACC.npy'.format(fold), np.array(train_ACC))
        np.save('./Kfold_models/fold{}/test_LOSS.npy'.format(fold), np.array(test_LOSS))
        np.save('./Kfold_models/fold{}/test_ACC.npy'.format(fold), np.array(test_ACC))
        np.save('./Kfold_models/fold{}/val_LOSS.npy'.format(fold), np.array(val_LOSS))
        np.save('./Kfold_models/fold{}/val_ACC.npy'.format(fold), np.array(val_ACC))

        del model


if __name__ == '__main__':
    set_random_seed(0)
    train(save_all_checkpoint=False)
--------------------------------------------------------------------------------
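Each fold writes its per-epoch metrics to `./Kfold_models/fold{k}/*.npy`. A minimal sketch for inspecting a fold's learning curves after training, assuming matplotlib is available (this plotting code is not part of the repository):

import numpy as np
import matplotlib.pyplot as plt

fold = 0
train_acc = np.load('./Kfold_models/fold{}/train_ACC.npy'.format(fold))
val_acc = np.load('./Kfold_models/fold{}/val_ACC.npy'.format(fold))

plt.plot(train_acc, label='train acc')
plt.plot(val_acc, label='val acc')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()

--------------------------------------------------------------------------------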
/README.md:
--------------------------------------------------------------------------------
# MultiChannelSleepNet
### MultiChannelSleepNet: A Transformer-based Model for Automatic Sleep Stage Classification with PSG
#### *by: Yang Dai, Xiuli Li, Shanshan Liang, Lukang Wang, Qingtian Duan, Hui Yang, Chunqing Zhang, Xiaowei Chen, Longhui Li\*, Xingyi Li\*, and Xiang Liao\*
This work has been accepted for publication in the [IEEE Journal of Biomedical and Health Informatics (JBHI)](https://ieeexplore.ieee.org/document/10146380).

## Abstract
![MultiChannelSleepNet architecture](imgs/MultiChannelSleepNet.png)
Automatic sleep stage classification plays an essential role in sleep quality measurement and sleep disorder diagnosis. Although many approaches have been developed, most use only single-channel electroencephalogram signals for classification. Polysomnography (PSG) provides multiple channels of signal recording, enabling the use of an appropriate method to extract and integrate the information from different channels to achieve higher sleep staging performance. We present a transformer encoder-based model, MultiChannelSleepNet, for automatic sleep stage classification with multichannel PSG data, whose architecture is built on transformer encoders for single-channel feature extraction and multichannel feature fusion. In a single-channel feature extraction block, transformer encoders extract features from the time-frequency images of each channel independently. Based on our integration strategy, the feature maps extracted from each channel are fused in the multichannel feature fusion block. Another set of transformer encoders further captures joint features, and a residual connection preserves the original information from each channel in this block. Experimental results on three publicly available datasets demonstrate that our method achieves higher classification performance than state-of-the-art techniques. MultiChannelSleepNet is an efficient method to extract and integrate the information from multichannel PSG data, which facilitates precision sleep staging in clinical applications.
## Requirements:
- Python 3.6
- pytorch==1.9.1
- numpy
- scikit-learn
- scipy==1.5.4
- mne==0.23.4
- tqdm

## Data
We used three public datasets in this study:

- SleepEDF-20 (2013 version)
- [SleepEDF-78](https://physionet.org/content/sleep-edfx/1.0.0/) (2018 version)
- [SHHS](https://sleepdata.org/datasets/shhs)

This project currently provides preprocessing code only for SleepEDF-20 and SleepEDF-78, and only supports sample-wise k-fold cross-validation. We will update the code in the future.
After downloading the datasets, place them in the correspondingly named folders under the `dataset` directory.
You can run `dataset_prepare.py` to extract events from the original recordings (.edf).

## Reproducibility
If you want to update the training parameters, edit the `args.py` file. In this file, you can update:

- Device (GPU or CPU).
- Batch size.
- Number of folds (as we use K-fold cross-validation).
- Number of training epochs.
- Parameters of our model (dropout rate, number of transformer encoders, etc.).

To easily reproduce the results, follow these steps:

1. Run `dataset_prepare.py` to extract events from the original recordings (.edf).
2. Run `data_preprocess_TF.py` to preprocess the data. The original signals are converted to time-frequency images and normalized.
3. Run `Kfold_trainer.py` to perform the standard K-fold cross-validation.
4. Run `result_evaluate.py` to get the evaluation report. It includes the evaluation metrics we described in the paper.


## Contact
Yang Dai
Center for Neurointelligence, School of Medicine
Chongqing University, Chongqing 400030, China
Email: valar_d@163.com
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
import torch
import os


class Config(object):
    """args in model and trainer"""
    def __init__(self):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.num_fold = 10
        self.num_classes = 5
        self.num_epochs = 200       # because early stopping is used, this parameter can be relatively large
        self.batch_size = 64
        self.pad_size = 29          # time dimension of the TF image
        self.learning_rate = 5e-6
        self.dropout = 0.1          # dropout rate in the transformer encoders
        self.dim_model = 128        # frequency dimension of the TF image
        self.forward_hidden = 1024  # hidden units of the transformer encoder feed-forward layers
        self.fc_hidden = 1024       # hidden units of the FC layers
        self.num_head = 8
        self.num_encoder = 16       # number of encoders in the single-channel feature extraction block
        self.num_encoder_multi = 4  # number of encoders in the multichannel feature fusion block


class Path(object):
    """paths of the files in this project"""
    def __init__(self):
        self.path_PSG = 'dataset/sleepEDF-78/sleep-cassette'
        self.path_hypnogram = 'dataset/sleepEDF-78/Hypnogram'
        self.path_raw_data = 'data/sleepEDF-78/data_array/raw_data'
        self.path_labels = 'data/sleepEDF-78/data_array/raw_data/labels'
        self.path_TF = 'data/sleepEDF-78/data_array/TF_data'

        os.makedirs(self.path_hypnogram, exist_ok=True)
        os.makedirs(self.path_raw_data, exist_ok=True)
        os.makedirs(self.path_TF, exist_ok=True)
--------------------------------------------------------------------------------
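The `pad_size` and `dim_model` values are tied to the shape of the time-frequency images produced by `data_preprocess_TF.py` (29 time steps × 128 frequency bins). A minimal consistency check, assuming the default values above:

from args import Config

config = Config()
# Each TF image is (pad_size, dim_model) = (29, 128); the fused feature map
# flattens to pad_size * dim_model * 3 values, the input size of fc1 in model.py.
assert config.pad_size == 29 and config.dim_model == 128
print('fc1 input size:', config.pad_size * config.dim_model * 3)  # 11136

--------------------------------------------------------------------------------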
/data_loader.py:
--------------------------------------------------------------------------------
import os
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

from args import Config


def data_generator(path_labels, path_dataset):
    config = Config()
    # Sort so the label files are concatenated in the same deterministic order
    # as the data arrays in data_preprocess_TF.py.
    dir_annotation = sorted(os.listdir(path_labels))

    first = True
    for f in dir_annotation:
        if first:
            labels = np.load(os.path.join(path_labels, f))
            first = False
        else:
            temp = np.load(os.path.join(path_labels, f))
            labels = np.append(labels, temp, axis=0)
    labels = torch.from_numpy(labels)

    dataset_EEG_FpzCz = np.load(os.path.join(path_dataset, 'TF_EEG_Fpz-Cz_mean_std.npy')).astype('float32')
    dataset_EEG_PzOz = np.load(os.path.join(path_dataset, 'TF_EEG_Pz-Oz_mean_std.npy')).astype('float32')
    dataset_EOG = np.load(os.path.join(path_dataset, 'TF_EOG_mean_std.npy')).astype('float32')

    # stack the three channels: (N, 3, 29, 128)
    dataset = np.stack((dataset_EEG_FpzCz, dataset_EEG_PzOz, dataset_EOG), axis=1)
    dataset = torch.from_numpy(dataset)

    print('dataset: ', dataset.shape)

    # hold out the validation set
    X_train_test, X_val, y_train_test, y_val = train_test_split(
        dataset, labels, test_size=1 / (config.num_fold + 1), random_state=0, stratify=labels)

    val_set = TensorDataset(X_val, y_val)
    val_loader = DataLoader(dataset=val_set, batch_size=config.batch_size, shuffle=False)

    print('val_set:', len(X_val))
    return X_train_test, y_train_test, val_loader
--------------------------------------------------------------------------------
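A minimal usage sketch, assuming the preprocessed `.npy` files from `data_preprocess_TF.py` already exist under the paths in `args.Path`:

from args import Path
from data_loader import data_generator

path = Path()
dataset, labels, val_loader = data_generator(path_labels=path.path_labels, path_dataset=path.path_TF)
# dataset: (N, 3, 29, 128) float32 tensor (Fpz-Cz EEG, Pz-Oz EEG, EOG)
# labels:  (N,) with classes 0..4 = W, N1, N2, N3, REM
print(dataset.shape, labels.shape)

--------------------------------------------------------------------------------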
/data_preprocess_TF.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from scipy.fftpack import fft
from scipy import signal
from tqdm import tqdm

from args import Path


def data_array_concat(path_array):
    """Concatenate the data from each subject."""
    # Sort so every channel (and the labels in data_loader.py) is concatenated
    # in the same deterministic order.
    dir_PSG = sorted(os.listdir(path_array))
    first = True
    print('Preparing dataset:')
    for f in tqdm(dir_PSG):
        if first:
            data_channel = np.load(os.path.join(path_array, f)).astype('float32')
            first = False
        else:
            temp = np.load(os.path.join(path_array, f)).astype('float32')
            data_channel = np.append(data_channel, temp, axis=0)
    data_channel = np.squeeze(data_channel, axis=1)  # (N, 1, 3000) -> (N, 3000)
    return data_channel


def spectrogram(x, window, n_overlap, nfft):
    """
    Transform a signal into a time-frequency image.
    This function imitates MATLAB's spectrogram function.
    Args:
        x (numpy array): data
        window (int): size of the window function
        n_overlap (int): number of overlapping points between two segments
        nfft (int): number of points used in the Fast Fourier Transform
    """
    len_x = len(x)
    step = window - n_overlap
    nn = nfft // 2 + 1
    num_win = int(np.floor((len_x - n_overlap) / (window - n_overlap)))
    spectrogram_data = []
    # Hamming window by default (moved to signal.windows.hamming in newer SciPy)
    win = signal.hamming(window)
    for i in range(num_win):
        subdata = x[i * step: i * step + window]
        F = fft(subdata * win, n=nfft)
        spectrogram_data.append(F[:nn])
    spectrogram_data = np.array(spectrogram_data)
    return spectrogram_data


def data_normalize(dataset, channel):
    """Normalize the dataset of each channel to zero mean and unit variance."""
    # 20*log10(|F|) produces -inf wherever the spectrum is exactly zero;
    # patch each inf with a neighboring value before normalizing.
    for i in tqdm(range(dataset.shape[0])):
        if np.isinf(dataset[i]).any():
            for j in range(29):
                if np.isinf(dataset[i][j]).any():
                    for k in range(128):
                        if np.isinf(dataset[i][j][k]):
                            if k != 127:
                                print('location of inf: ', i, ',', j, ',', k)
                            if k == 0:
                                if j == 0:
                                    dataset[i][j][k] = dataset[i][j + 1][k]
                                else:
                                    dataset[i][j][k] = dataset[i][j - 1][k]
                            else:
                                dataset[i][j][k] = dataset[i][j][k - 1]

    dataset = (dataset - np.mean(dataset)) / np.std(dataset)

    # save only when neither inf nor nan remains
    if not (np.isinf(dataset).any() or np.isnan(dataset).any()):
        np.save('./data/sleepEDF-78/data_array/TF_data/TF_{}_mean_std.npy'.format(channel), dataset)


if __name__ == '__main__':
    path = Path()

    fs = 100      # sampling rate (Hz)
    overlap = 1   # window overlap (s)
    nfft = 256
    win_size = 2  # window size (s)

    for channel in ['EEG_Fpz-Cz', 'EEG_Pz-Oz', 'EOG']:
        print('-' * 15, 'Processing channel:{}'.format(channel), '-' * 15)
        data_channel = data_array_concat(path_array=os.path.join(path.path_raw_data, channel))
        X = np.zeros([data_channel.shape[0], 29, int(nfft / 2)])
        print('Transform to TF images:')
        for i in tqdm(range(data_channel.shape[0])):
            Xi = spectrogram(data_channel[i, :], win_size * fs, overlap * fs, nfft)
            Xi = 20 * np.log10(abs(Xi))  # log power (dB)
            X[i, :, :] = Xi[:, 1:129]    # drop the DC bin, keep 128 frequency bins

        print('Normalize:')
        data_normalize(dataset=X, channel=channel)
--------------------------------------------------------------------------------
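Why the TF images are 29 × 128: a 30-s epoch at 100 Hz has 3000 samples; with a 2-s window (200 samples), 1-s overlap (100 samples), and nfft=256, the arithmetic below (a small self-contained check, not repository code) reproduces the shape used throughout:

import numpy as np

fs, win_size, overlap, nfft = 100, 2, 1, 256
len_x = 30 * fs                        # 3000 samples per 30-s epoch
window, n_overlap = win_size * fs, overlap * fs
step = window - n_overlap              # 100 samples
num_win = int(np.floor((len_x - n_overlap) / step))
n_bins = nfft // 2 + 1                 # 129 one-sided FFT bins

print(num_win, n_bins)                 # 29 129
# Xi[:, 1:129] then drops the DC bin, leaving a 29 x 128 image,
# matching Config.pad_size = 29 and Config.dim_model = 128.

--------------------------------------------------------------------------------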
/dataset_prepare.py:
--------------------------------------------------------------------------------
import os
import shutil

import mne
import numpy as np

from args import Path


def prepare_SleepEDF_20(path_PSG, path_hypnogram, save_path):
    """Extract 30-s epochs from the EDF files."""
    # move the hypnogram files into their own directory first
    for file in os.listdir(path_PSG):
        if "Hypnogram" in file:
            original_path = os.path.join(path_PSG, file)
            target_path = os.path.join(path_hypnogram, file)
            shutil.move(original_path, target_path)

    annotation_desc_2_event_id = {'Sleep stage W': 1,
                                  'Sleep stage 1': 2,
                                  'Sleep stage 2': 3,
                                  'Sleep stage 3': 4,
                                  'Sleep stage 4': 4,  # stages 3 and 4 are merged into N3
                                  'Sleep stage R': 5}
    event_id = {'Sleep stage W': 1,
                'Sleep stage 1': 2,
                'Sleep stage 2': 3,
                'Sleep stage 3/4': 4,
                'Sleep stage R': 5}

    event_id_with_no_N3N4 = {'Sleep stage W': 1,
                             'Sleep stage 1': 2,
                             'Sleep stage 2': 3,
                             'Sleep stage R': 5}

    # Sort both listings so each PSG recording is paired with its own hypnogram.
    dir_PSG = sorted(os.listdir(path_PSG))
    dir_annotation = sorted(os.listdir(path_hypnogram))

    for i, j in zip(dir_PSG, dir_annotation):
        print('current file: ', i, j)

        PSG_file = os.path.join(path_PSG, i)
        annotation_file = os.path.join(path_hypnogram, j)

        raw_train = mne.io.read_raw_edf(PSG_file, stim_channel='marker', misc=['rectal'])
        annotation_train = mne.read_annotations(annotation_file)
        raw_train.set_annotations(annotation_train, emit_warning=False)

        # keep 30 min of recording before the first and after the last sleep annotation
        annotation_train.crop(annotation_train[1]['onset'] - 30 * 60,
                              annotation_train[-2]['onset'] + 30 * 60)
        raw_train.set_annotations(annotation_train, emit_warning=False)
        events_train, sleep_stage_exist = mne.events_from_annotations(raw_train, event_id=annotation_desc_2_event_id,
                                                                      chunk_duration=30.)

        tmax = 30. - 1. / raw_train.info['sfreq']  # tmax is included

        if len(sleep_stage_exist) <= 4:
            epochs_train = mne.Epochs(raw=raw_train, events=events_train, event_id=event_id_with_no_N3N4, tmin=0.,
                                      tmax=tmax, baseline=None, preload=True)
        else:
            epochs_train = mne.Epochs(raw=raw_train, events=events_train, event_id=event_id, tmin=0., tmax=tmax,
                                      baseline=None, preload=True)

        X_train_eeg_FpzCz = epochs_train.copy().pick_channels(['EEG Fpz-Cz']).get_data()
        X_train_eeg_PzOz = epochs_train.copy().pick_channels(['EEG Pz-Oz']).get_data()
        X_train_eog = epochs_train.copy().pick_channels(['EOG horizontal']).get_data()
        y_train = epochs_train.copy().pick_channels(['EEG Fpz-Cz']).events[:, 2]
        y_train = y_train - 1  # shift labels to 0..4

        for channel in ['EEG_Fpz-Cz', 'EEG_Pz-Oz', 'EOG', 'labels']:
            os.makedirs(os.path.join(save_path, channel), exist_ok=True)

        np.save(save_path + '/EEG_Fpz-Cz/{}_EEG_Fpz-Cz.npy'.format(i[0:12]), X_train_eeg_FpzCz)
        np.save(save_path + '/EEG_Pz-Oz/{}_EEG_Pz-Oz.npy'.format(i[0:12]), X_train_eeg_PzOz)
        np.save(save_path + '/EOG/{}_EOG.npy'.format(i[0:12]), X_train_eog)
        np.save(save_path + '/labels/{}_label.npy'.format(i[0:12]), y_train)


if __name__ == '__main__':
    path = Path()
    prepare_SleepEDF_20(path_PSG=path.path_PSG, path_hypnogram=path.path_hypnogram, save_path=path.path_raw_data)
--------------------------------------------------------------------------------
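Each saved array has one row per 30-s epoch; at 100 Hz that is shape (n_epochs, 1, 3000), which `data_array_concat` in `data_preprocess_TF.py` later squeezes to (n_epochs, 3000). A quick check after running the script (the recording ID below is a placeholder; use any file actually produced):

import numpy as np

x = np.load('data/sleepEDF-78/data_array/raw_data/EEG_Fpz-Cz/SC4001E0-PSG_EEG_Fpz-Cz.npy')
y = np.load('data/sleepEDF-78/data_array/raw_data/labels/SC4001E0-PSG_label.npy')
print(x.shape)               # (n_epochs, 1, 3000): 30 s at 100 Hz
print(sorted(np.unique(y)))  # subset of [0, 1, 2, 3, 4] = W, N1, N2, N3, REM

--------------------------------------------------------------------------------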
/early_stop_tool.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import os


class EarlyStopping:
    """Stop training early if validation accuracy doesn't improve after a given patience."""
    def __init__(self, patience=20, verbose=False, delta=0, trace_func=print, save_all_checkpoint=False):
        """
        Args:
            patience (int): How long to wait after the last validation accuracy improvement.
                            Default: 20
            verbose (bool): If True, prints a message for each validation accuracy improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            trace_func (function): trace print function.
                            Default: print
            save_all_checkpoint (bool): If True, save the model at every checkpoint during
                            training; this requires a large amount of storage space.
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = -np.inf  # accuracy is maximized, so start from -inf
        self.delta = delta
        self.trace_func = trace_func
        self.save_all_checkpoint = save_all_checkpoint

    def __call__(self, val_acc, model, path):

        score = val_acc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_acc, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model, path)
            self.counter = 0

    def save_checkpoint(self, val_acc, model, path):
        """Save the model when validation accuracy increases."""
        if self.verbose:
            self.trace_func(f'Validation acc increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Saving model ...')
        if not self.save_all_checkpoint:
            # overwrite a single best model per fold instead of one file per epoch
            path = os.path.join(os.path.dirname(path), 'model.pkl')
        torch.save(model.state_dict(), path)
        self.val_acc_max = val_acc
--------------------------------------------------------------------------------
/imgs/MultiChannelSleepNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangdai97/MultiChannelSleepNet/290daf0f5ec75f06a9a60b524f39bd22fb9725d7/imgs/MultiChannelSleepNet.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Implement the PE function."""

    def __init__(self, d_model=128, dropout=0.2, max_len=30):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # pe: [1, 30, 128]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # the registered buffer requires no grad, so it can be added directly
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()

        self.position_single = PositionalEncoding(d_model=config.dim_model, dropout=0.1)

        encoder_layer = nn.TransformerEncoderLayer(d_model=config.dim_model, nhead=config.num_head,
                                                   dim_feedforward=config.forward_hidden, dropout=config.dropout)
        # nn.TransformerEncoder deep-copies encoder_layer, so the three
        # single-channel feature extraction blocks do not share weights.
        self.transformer_encoder_1 = nn.TransformerEncoder(encoder_layer, num_layers=config.num_encoder)
        self.transformer_encoder_2 = nn.TransformerEncoder(encoder_layer, num_layers=config.num_encoder)
        self.transformer_encoder_3 = nn.TransformerEncoder(encoder_layer, num_layers=config.num_encoder)

        self.drop = nn.Dropout(p=0.5)
        self.layer_norm = nn.LayerNorm(config.dim_model * 3)

        # multichannel feature fusion block
        self.position_multi = PositionalEncoding(d_model=config.dim_model * 3, dropout=0.1)
        encoder_layer_multi = nn.TransformerEncoderLayer(d_model=config.dim_model * 3, nhead=config.num_head,
                                                         dim_feedforward=config.forward_hidden, dropout=config.dropout)
        self.transformer_encoder_multi = nn.TransformerEncoder(encoder_layer_multi, num_layers=config.num_encoder_multi)

        self.fc1 = nn.Sequential(
            nn.Linear(config.pad_size * config.dim_model * 3, config.fc_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(config.fc_hidden, config.num_classes)
        )

    def forward(self, x):
        # x: (batch_size, 3, 29, 128) -> one TF image per channel
        x1 = x[:, 0, :, :]
        x2 = x[:, 1, :, :]
        x3 = x[:, 2, :, :]
        x1 = self.position_single(x1)
        x2 = self.position_single(x2)
        x3 = self.position_single(x3)

        x1 = self.transformer_encoder_1(x1)  # (batch_size, 29, 128)
        x2 = self.transformer_encoder_2(x2)
        x3 = self.transformer_encoder_3(x3)

        x = torch.cat([x1, x2, x3], dim=2)  # (batch_size, 29, 384)

        x = self.drop(x)
        x = self.layer_norm(x)
        residual = x

        x = self.position_multi(x)
        x = self.transformer_encoder_multi(x)

        x = self.layer_norm(x + residual)  # residual connection

        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
--------------------------------------------------------------------------------
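A minimal smoke test, assuming the default `Config`: feed a random batch shaped like the stacked TF images and check the logits shape.

import torch
from args import Config
from model import Transformer

config = Config()
model = Transformer(config).to(config.device)
model.eval()

# random batch: (batch, channels, time, frequency) = (4, 3, 29, 128)
x = torch.randn(4, 3, config.pad_size, config.dim_model).to(config.device)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([4, 5]) -> one score per sleep stage

--------------------------------------------------------------------------------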
/result_evaluate.py:
--------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm

from sklearn.metrics import recall_score, accuracy_score, f1_score, cohen_kappa_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold

import torch
from torch.utils.data import TensorDataset, DataLoader

from model import Transformer
from data_loader import data_generator
from args import Config, Path


def specificity(y_true, y_pred, n=5):
    """Average one-vs-rest specificity over the n classes."""
    spec = []
    con_mat = confusion_matrix(y_true, y_pred)  # rows are the ground truth, columns are the predictions
    for i in range(n):
        number = np.sum(con_mat[:, :])
        tp = con_mat[i][i]
        fn = np.sum(con_mat[i, :]) - tp
        fp = np.sum(con_mat[:, i]) - tp
        tn = number - tp - fn - fp
        spec1 = tn / (tn + fp)
        spec.append(spec1)
    average_specificity = np.mean(spec)
    return average_specificity


def class_wise_evaluate(con_mat):
    """
    Calculate the class-wise results from the confusion matrix.
    Rows: Wake, N1, N2, N3, REM
    Columns: precision, recall, F1_score
    """
    class_wise_mat = np.empty((5, 3))
    for i in range(5):
        precision = con_mat[i, i] / np.sum(con_mat[:, i])
        recall = con_mat[i, i] / np.sum(con_mat[i, :])
        F1_score = (2 * precision * recall) / (precision + recall)
        class_wise_mat[i, 0] = precision
        class_wise_mat[i, 1] = recall
        class_wise_mat[i, 2] = F1_score

    return class_wise_mat


def test(model, test_loader, config):
    model.eval()

    pred = []
    label = []

    with torch.no_grad():
        loop = tqdm(enumerate(test_loader), total=len(test_loader))
        for batch_idx, (data, target) in loop:
            data = data.to(config.device)
            target = target.to(config.device)

            output = model(data)

            pred.extend(np.argmax(output.data.cpu().numpy(), axis=1))
            label.extend(target.data.cpu().numpy())

    accuracy = accuracy_score(label, pred)
    cohens_kappa = cohen_kappa_score(label, pred)
    macro_f1 = f1_score(label, pred, average='macro')
    average_sensitivity = recall_score(label, pred, average="macro")  # sensitivity and recall are the same concept
    average_specificity = specificity(label, pred, n=5)

    print('ACC: %.4f' % accuracy, 'k: %.4f' % cohens_kappa, 'MF1: %.4f' % macro_f1,
          'Sens: %.4f' % average_sensitivity, 'Spec: %.4f' % average_specificity)

    con_mat = confusion_matrix(label, pred)

    return accuracy, cohens_kappa, macro_f1, average_sensitivity, average_specificity, con_mat


def evaluate(config, path):
    dataset, labels, val_loader = data_generator(path_labels=path.path_labels, path_dataset=path.path_TF)

    kf = StratifiedKFold(n_splits=config.num_fold, shuffle=True, random_state=0)

    ACC = 0
    Kappa = 0
    MF1 = 0
    Sens = 0
    Spec = 0
    Confusion_mat = np.zeros([5, 5])

    for fold, (train_idx, test_idx) in enumerate(kf.split(dataset, labels)):
        print('-' * 15, '>', f'Fold {fold}', '<', '-' * 15)

        path_model = './Kfold_models/fold{}/model.pkl'.format(fold)

        _, X_test = dataset[train_idx], dataset[test_idx]
        _, y_test = labels[train_idx], labels[test_idx]
        test_set = TensorDataset(X_test, y_test)
        test_loader = DataLoader(dataset=test_set, batch_size=config.batch_size, shuffle=False)

        print('train_set: ', len(train_idx))
        print('test_set: ', len(test_idx))

        model = Transformer(config)
        model = model.to(config.device)
        # map_location allows CPU-only evaluation of GPU-trained checkpoints
        model.load_state_dict(torch.load(path_model, map_location=config.device), strict=True)

        accuracy, cohens_kappa, macro_f1, average_sensitivity, average_specificity, con_mat = test(model, test_loader, config)

        ACC += accuracy
        Kappa += cohens_kappa
        MF1 += macro_f1
        Sens += average_sensitivity
        Spec += average_specificity

        Confusion_mat += con_mat

        del model

    # average the scalar metrics over the folds; the confusion matrices are summed
    ACC /= config.num_fold
    Kappa /= config.num_fold
    MF1 /= config.num_fold
    Sens /= config.num_fold
    Spec /= config.num_fold

    class_wise_result = class_wise_evaluate(Confusion_mat)

    return ACC, Kappa, MF1, Sens, Spec, Confusion_mat, class_wise_result


if __name__ == '__main__':
    config = Config()
    path = Path()

    ACC, Kappa, MF1, Sens, Spec, Confusion_mat, class_wise_result = evaluate(config=config, path=path)

    print('ACC: ', ACC)
    print('Cohen\'s Kappa: ', Kappa)
    print('MF1: ', MF1)
    print('Sens: ', Sens)
    print('Spec: ', Spec)
    print('confusion_mat:')
    print(Confusion_mat)
    print('class_wise_result: ')
    print(class_wise_result)
--------------------------------------------------------------------------------
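As a sanity check for specificity and class_wise_evaluate, here is a small illustration on made-up predictions (the values below are hypothetical, not results from the paper):

import numpy as np
from sklearn.metrics import confusion_matrix
from result_evaluate import specificity, class_wise_evaluate

# Toy example: two N1 epochs (class 1) misread as N2 (class 2).
y_true = [0, 0, 1, 1, 1, 2, 2, 3, 4, 4]
y_pred = [0, 0, 1, 2, 2, 2, 2, 3, 4, 4]

print(specificity(y_true, y_pred, n=5))  # 0.95: only N2 picks up false positives
print(class_wise_evaluate(confusion_matrix(y_true, y_pred)))
# e.g. N2 row: precision 0.5 (2 of 4 predicted N2 are correct), recall 1.0

--------------------------------------------------------------------------------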