├── README.assets
│   └── model.png
├── README.md
├── preprocess
│   ├── run_preprocess.py
│   └── Utils.py
├── Utils
│   └── two_stream_dataloader.py
├── train.py
└── model.py

/README.assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiyangcai/SleepPrintNet/HEAD/README.assets/model.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SleepPrintNet
2 | 
3 | SleepPrintNet: A Multivariate Multimodal Neural Network based on Physiological Time-series for Automatic Sleep Staging
4 | 
5 | ![model](README.assets/model.png)
6 | 
7 | SleepPrintNet is made up of four independent modules: an EEG temporal feature extraction module, an EEG spectral-spatial feature extraction module, an EOG feature extraction module, and an EMG feature extraction module. After extracting the corresponding features with each module, we use a feature fusion layer to fuse all the features and finally obtain the classification result through a classification layer.
8 | 
9 | This is the source code of SleepPrintNet.
10 | 
11 | # Dataset
12 | 
13 | We evaluate our model on Subset-3 of the [Montreal Archive of Sleep Studies (MASS)](http://ceams-carsm.ca/en/mass/) dataset. MASS is an open-access and collaborative database of laboratory-based polysomnography (PSG) recordings.
14 | 
15 | # Requirements
16 | 
17 | - Python 3.6.5
18 | - CUDA 9.0
19 | - CuDNN 7.5.1
20 | - numpy==1.15.0
21 | - sklearn==0.19.1
22 | - tensorflow_gpu==1.8.0
23 | - Keras==2.2.0
24 | - matplotlib==3.0.3
25 | 
26 | # Usage
27 | 
28 | - Preprocess
29 | 
30 |   - Prepare data (raw signals) in `data_dir`
31 |     - File name: 01-03-00**XX**-Data.npy, where **XX** denotes the subject ID.
32 |     - Tensor shape: **[sample, channel=26, length]**, where channel 0 is the ECG channel, channels 1-20 are EEG channels, 21-23 are EMG channels, and 24-25 are EOG channels.
33 |   - Modify directories
34 |     - Modify `data_dir` and `label_dir` in `run_preprocess.py` accordingly.
35 |     - (Optional) Modify `output_dir` in `run_preprocess.py`.
36 |   - Run the preprocessing program
37 |     - `cd preprocess`
38 |     - `python run_preprocess.py`
39 | 
40 | - Command Line Parameters
41 | 
42 |   - Training
43 |     - `--batch_size`: Training batch size.
44 |     - `--epoch`: Number of training epochs.
45 |     - `--num_fold`: Number of folds.
46 |     - `--save_model`: Whether to save the best model of each fold.
47 |   - Directory
48 |     - `--model_dir`: The directory for saving the best model of each fold.
49 |     - `--data_dir1`: The directory of the EEG and EOG signals.
50 |     - `--data_dir2`: The directory of the EMG signals.
51 |     - `--data_dir3`: The directory of the spectral-spatial representation of the EEG signals.
52 |     - `--result_dir`: The directory for saving results.
53 | 
54 | - Input Data Shape
55 | 
56 |   - EEG_EOG
57 |     - Data**SubID**.npy: (numOfSamples, numOfChannels, timeLength) -> (numOfSamples, 6 + 2, 30 * 128 * 3)
58 |       - numOfChannels: 6 EEG channels and 2 EOG channels.
59 |       - timeLength: 30 (s) * 128 (Hz) * 3 (epochs) = 11520
60 |     - Label**SubID**.npy: (numOfSamples, )
61 |   - EMG
62 |     - Data**SubID**.npy: (numOfSamples, numOfChannels, timeLength) -> (numOfSamples, 3, 11520)
63 |       - numOfChannels: 3 EMG channels.
64 |       - timeLength: 30 (s) * 128 (Hz) * 3 (epochs) = 11520
65 |     - Label**SubID**.npy: (numOfSamples, )
66 |   - fre_spa
67 |     - Data**SubID**.npy: (numOfSamples, numOfFreqBands, height, width, 1) -> (numOfSamples, 5, 16, 16, 1)
68 |       - numOfFreqBands: 5 frequency bands
69 |       - height, width: 16 $\times$ 16 topographic map
70 |     - Label**SubID**.npy: (numOfSamples, )
71 | 
72 | - Training
73 | 
74 |   Run `train.py` with the command line parameters above. By default, the model can be trained with the following command:
75 | 
76 |   ```
77 |   CUDA_VISIBLE_DEVICES=0 python train.py
78 |   ```
79 | 
--------------------------------------------------------------------------------
/preprocess/run_preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from Utils import *
3 | import os
4 | import glob
5 | 
6 | stft_para = {
7 |     'stftn': 3840,
8 |     'fStart': [1, 4, 8, 14, 31],
9 |     'fEnd': [3, 7, 13, 30, 50],
10 |     'fs': 128,
11 |     'window': 30,
12 | }
13 | 
14 | data_dir = '../SS3/data/'
15 | label_dir = '../SS3/label/'
16 | 
17 | output_dir = '../data/'
18 | output_dir_fre_spa = os.path.join(output_dir, 'fre_spa')
19 | output_dir_EEG_EOG = os.path.join(output_dir, 'EEG_EOG')
20 | output_dir_EMG = os.path.join(output_dir, 'EMG')
21 | 
22 | def GenerateTopographic(data, sub_id):
23 |     '''
24 |     data: NumPy Tensor (Sample, EEG Channels, Time Length)
25 |     sub_id: int
26 |     '''
27 |     data = butter_bandpass_filter(data, 0.5, 50, 128, 4)
28 |     print("Temporal data shape:", data.shape)  # Data shape (1005, 20, 7680)
29 | 
30 |     psd_frequency = []
31 |     for sample in range(data.shape[0]):
32 |         psd, _ = DE_PSD(data[sample], stft_para)  # PSD shape (20, 5)
33 |         psd_frequency.append(psd)
34 |     psd_frequency = np.array(psd_frequency)
35 |     print("PSD data shape:", psd_frequency.shape)
36 |     MY_frequency = norm(psd_frequency)
37 | 
38 |     heatmap = convert_heat(MY_frequency)
39 | 
40 |     heatmap_spectral = np.zeros(
41 |         [MY_frequency.shape[0], 5, 16, 16], dtype='float32')
42 |     for ep in range(heatmap.shape[0]):
43 |         for hz in range(heatmap.shape[1]):
44 |             heatmap_spectral[ep, hz, :, :] = grid_data(heatmap[ep][hz])
45 |     heatmap_spectral = heatmap_spectral[:, :, :, :, np.newaxis]
46 |     return heatmap_spectral
47 | 
48 | def main():
49 |     if not os.path.exists(output_dir):
50 |         os.mkdir(output_dir)
51 |     if not os.path.exists(output_dir_fre_spa):
52 |         os.mkdir(output_dir_fre_spa)
53 |     if not os.path.exists(output_dir_EEG_EOG):
54 |         os.mkdir(output_dir_EEG_EOG)
55 |     if not os.path.exists(output_dir_EMG):
56 |         os.mkdir(output_dir_EMG)
57 | 
58 |     for sub_id in range(1, 65):
59 |         print(f'Subject {sub_id}')
60 |         if sub_id in (43, 49):  # Exclude 43 and 49
61 |             continue
62 | 
63 |         data = np.load(os.path.join(data_dir, f'01-03-00{sub_id:02}-Data.npy'))
64 |         # data's shape: [sample, channel, length] where
65 |         # CHANNEL: 0 -> ECG channel, 1-20 -> EEG channels,
66 |         # 21-23 -> EMG channels, 24-25 -> EOG channels
67 |         label = np.load(os.path.join(label_dir, f'subject{sub_id}.npy'))
68 | 
69 | 
70 |         # Prepare Spectral Spatial Representation of EEG signals
71 |         EEG = data[:, 1:21, :]  # Select EEG channels
72 |         Fre_Spa_Representation = GenerateTopographic(EEG, sub_id)
73 | 
74 |         # Prepare raw EEG and EOG signals
75 |         EEG_EOG_Channels = [4, 5, 1, 2, 11, 12, 24, 25]
76 |         EEG_EOG_Representation = data[:, EEG_EOG_Channels, :]
77 | 
78 |         # Prepare raw EMG signals
79 |         EMG_Representation = data[:, 21:24, :]
80 | 
81 |         Fre_Spa_Representation = AddContext(Fre_Spa_Representation,
add_context=False).astype(np.float32) 82 | EEG_EOG_Representation = AddContext(EEG_EOG_Representation, add_context=True).astype(np.float32) 83 | EMG_Representation = AddContext(EMG_Representation, add_context=True).astype(np.float32) 84 | label = AddContext(label, add_context=False) 85 | 86 | np.save(os.path.join(output_dir_fre_spa, f'Data{sub_id}'), Fre_Spa_Representation) 87 | np.save(os.path.join(output_dir_fre_spa, f'Label{sub_id}'), label) 88 | np.save(os.path.join(output_dir_EEG_EOG, f'Data{sub_id}'), EEG_EOG_Representation) 89 | np.save(os.path.join(output_dir_EEG_EOG, f'Label{sub_id}'), label) 90 | np.save(os.path.join(output_dir_EMG, f'Data{sub_id}'), EMG_Representation) 91 | np.save(os.path.join(output_dir_EMG, f'Label{sub_id}'), label) 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /Utils/two_stream_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | class SeqDataLoader(): 6 | def __init__(self, data_dir, n_folds, fold_idx, classes, n_files): 7 | self.data_dir = data_dir 8 | self.n_folds = n_folds 9 | self.fold_idx = fold_idx 10 | self.classes = classes 11 | self.n_files = n_files 12 | 13 | def _load_npy_list_files(self, data_files, label_files): 14 | data = [] 15 | labels = [] 16 | for data_name, label_name in zip(data_files, label_files): 17 | #print ("Loading {} {} ...".format(data_name,label_name)) 18 | tmp_data = np.load(data_name) 19 | tmp_labels = np.load(label_name) 20 | tmp_labels = tmp_labels.astype(int) 21 | data.append(tmp_data) 22 | labels.append(tmp_labels) 23 | return data, labels 24 | 25 | def print_n_samples_each_class(self, labels, classes): 26 | class_dict = dict(zip(range(len(classes)), classes)) 27 | unique_labels = np.unique(labels) 28 | for c in unique_labels: 29 | n_samples = len(np.where(labels == c)[0]) 30 | print("{}: {}".format(class_dict[c], n_samples)) 31 | 32 | def load_data(self, shuffle=False): 33 | 34 | allfiles = os.listdir(self.data_dir) 35 | npyfiles = [] 36 | for f in allfiles: 37 | if ".npy" in f: 38 | npyfiles.append(os.path.join(self.data_dir, f)) 39 | 40 | npyfiles.sort(key=lambda x: (len(x), x)) 41 | 42 | datafiles = npyfiles[:len(npyfiles)//2] 43 | labelfiles = npyfiles[len(npyfiles)//2:] 44 | datafiles = datafiles[:self.n_files] 45 | labelfiles = labelfiles[:self.n_files] 46 | 47 | # Divide Training & Testing Sets 48 | r_permute = np.random.permutation(len(datafiles)) 49 | filename = os.path.join("r_permute{}.npz".format(len(datafiles))) 50 | if (os.path.isfile(filename)): 51 | with np.load(filename) as f: 52 | print("already exist") 53 | r_permute = f["inds"] 54 | else: 55 | save_dict = { 56 | "inds": r_permute, 57 | } 58 | np.savez(filename, **save_dict) 59 | 60 | datafiles = np.asarray(datafiles)[r_permute] 61 | labelfiles = np.asarray(labelfiles)[r_permute] 62 | traindata_files = np.array_split(datafiles, self.n_folds) 63 | trainlabel_files = np.array_split(labelfiles, self.n_folds) 64 | subjectdata_files = traindata_files[self.fold_idx] 65 | subjectlabel_files = trainlabel_files[self.fold_idx] 66 | traindata_files = list(set(datafiles) - set(subjectdata_files)) 67 | trainlabel_files = list(set(labelfiles) - set(subjectlabel_files)) 68 | traindata_files.sort(key=lambda x: (len(x), x)) 69 | trainlabel_files.sort(key=lambda x: (len(x), x)) 70 | 71 | # Load training and validation sets 72 | print("\n========== [Fold-{}] ==========\n".format(self.fold_idx)) 
73 | print("Load training set:") 74 | data_train, label_train = self._load_npy_list_files( 75 | traindata_files, trainlabel_files) 76 | print(" ") 77 | print("Load Test set:") 78 | data_test, label_test = self._load_npy_list_files( 79 | subjectdata_files, subjectlabel_files) 80 | print(" ") 81 | print("Training set: n_subjects={}".format(len(data_train))) 82 | n_train_examples = 0 83 | for d in data_train: 84 | n_train_examples += d.shape[0] 85 | print("Number of examples = {}".format(n_train_examples)) 86 | self.print_n_samples_each_class(np.hstack(label_train), self.classes) 87 | print(" ") 88 | print("Test set: n_subjects = {}".format(len(data_test))) 89 | n_test_examples = 0 90 | for d in data_test: 91 | n_test_examples += d.shape[0] 92 | print("Number of examples = {}".format(n_test_examples)) 93 | self.print_n_samples_each_class(np.hstack(label_test), self.classes) 94 | print(" ") 95 | 96 | data_train = np.vstack(data_train) 97 | label_train = np.hstack(label_train) 98 | 99 | data_test = np.vstack(data_test) 100 | label_test = np.hstack(label_test) 101 | 102 | if shuffle is True: 103 | # training data 104 | permute = np.random.permutation(len(label_train)) 105 | data_train = np.asarray(data_train) 106 | data_train = data_train[permute] 107 | label_train = label_train[permute] 108 | 109 | return data_train, label_train, data_test, label_test 110 | -------------------------------------------------------------------------------- /preprocess/Utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from scipy.fftpack import fft 4 | from scipy.signal import butter, lfilter 5 | from scipy.interpolate import griddata 6 | 7 | 8 | def DE_PSD(data, stft_para): 9 | ''' 10 | input: data [n*m] n electrodes, m time points 11 | stft_para.stftn frequency domain sampling rate 12 | stft_para.fStart start frequency of each frequency band 13 | stft_para.fEnd end frequency of each frequency band 14 | stft_para.window window length of each sample point(seconds) 15 | stft_para.fs original frequency 16 | output:psd,DE [n*l*k] n electrodes, l windows, k frequency bands 17 | ''' 18 | 19 | # Initialize the parameters 20 | STFTN = stft_para['stftn'] 21 | fStart = stft_para['fStart'] 22 | fEnd = stft_para['fEnd'] 23 | fs = stft_para['fs'] 24 | window = stft_para['window'] 25 | 26 | fStartNum = np.zeros([len(fStart)], dtype=int) 27 | fEndNum = np.zeros([len(fEnd)], dtype=int) 28 | for i in range(0, len(stft_para['fStart'])): 29 | fStartNum[i] = int(fStart[i]/fs*STFTN) 30 | fEndNum[i] = int(fEnd[i]/fs*STFTN) 31 | 32 | n = data.shape[0] 33 | m = data.shape[1] 34 | 35 | # print(m,n,l) 36 | psd = np.zeros([n, len(fStart)]) 37 | de = np.zeros([n, len(fStart)]) 38 | # Hanning window 39 | Hlength = window*fs 40 | # Hwindow=hanning(Hlength); 41 | Hwindow = np.array([0.5 - 0.5 * np.cos(2 * np.pi * n / (Hlength+1)) 42 | for n in range(1, Hlength+1)]) 43 | 44 | WindowPoints = fs*window 45 | dataNow = data[0:n] 46 | for j in range(0, n): 47 | temp = dataNow[j] 48 | Hdata = temp * Hwindow 49 | FFTdata = fft(Hdata, STFTN) 50 | magFFTdata = abs(FFTdata[0: int(STFTN/2)]) 51 | for p in range(0, len(fStart)): 52 | E = 0 53 | E_log = 0 54 | for p0 in range(fStartNum[p]-1, fEndNum[p]): 55 | E = E + magFFTdata[p0] * magFFTdata[p0] 56 | E = E / (fEndNum[p] - fStartNum[p] + 1) 57 | psd[j][p] = E 58 | de[j][p] = math.log(100*E, 2) 59 | 60 | return psd, de 61 | 62 | 63 | def butter_bandpass(lowcut, highcut, fs, order): 64 | nyq = 0.5 * fs 65 | low = lowcut / nyq 
66 | high = highcut / nyq 67 | b, a = butter(order, [low, high], btype='band') 68 | return b, a 69 | 70 | 71 | def butter_bandpass_filter(data, lowcut, highcut, fs, order): 72 | b, a = butter_bandpass(lowcut, highcut, fs, order=order) 73 | y = lfilter(b, a, data) 74 | return y 75 | 76 | def grid_data(data): 77 | grid_x, grid_y = np.mgrid[0:4:16j, 0:4:16j] 78 | points = [] 79 | for i in range(5): 80 | for j in range(5): 81 | points.append([i, j]) 82 | values = [] 83 | for x in points: 84 | values.append(data[x[0]][x[1]][0]) 85 | points = np.array(points) 86 | values = np.array(values) 87 | grid_z = griddata(points, values, (grid_x, grid_y), method='cubic') 88 | return grid_z 89 | 90 | def norm(pxx): 91 | mean = pxx.mean(axis=-2) 92 | std = pxx.std(axis=-2) 93 | for i in range(pxx.shape[0]): 94 | for j in range(pxx.shape[1]): 95 | for k in range(pxx.shape[2]): 96 | pxx[i][j][k] -= mean[i][k] 97 | for i in range(pxx.shape[0]): 98 | for j in range(pxx.shape[1]): 99 | for k in range(pxx.shape[2]): 100 | pxx[i][j][k] /= std[i][k] 101 | return pxx 102 | 103 | def convert_heat(pxx): 104 | heatmap = np.zeros([pxx.shape[0], pxx.shape[2], 5, 5, 1]) 105 | for ep in range(pxx.shape[0]): 106 | for hz in range(pxx.shape[2]): 107 | 108 | heatmap[ep][hz][0][1][0] = pxx[ep][9][hz] 109 | heatmap[ep][hz][0][3][0] = pxx[ep][16][hz] 110 | 111 | heatmap[ep][hz][1][0][0] = pxx[ep][8][hz] 112 | heatmap[ep][hz][1][1][0] = pxx[ep][14][hz] 113 | heatmap[ep][hz][1][2][0] = pxx[ep][12][hz] 114 | heatmap[ep][hz][1][3][0] = pxx[ep][11][hz] 115 | heatmap[ep][hz][1][4][0] = pxx[ep][0][hz] 116 | 117 | heatmap[ep][hz][2][0][0] = pxx[ep][19][hz] 118 | heatmap[ep][hz][2][1][0] = pxx[ep][7][hz] 119 | heatmap[ep][hz][2][2][0] = pxx[ep][5][hz] 120 | heatmap[ep][hz][2][3][0] = pxx[ep][3][hz] 121 | heatmap[ep][hz][2][4][0] = pxx[ep][13][hz] 122 | 123 | heatmap[ep][hz][3][0][0] = pxx[ep][4][hz] 124 | heatmap[ep][hz][3][1][0] = pxx[ep][2][hz] 125 | heatmap[ep][hz][3][2][0] = pxx[ep][17][hz] 126 | heatmap[ep][hz][3][3][0] = pxx[ep][6][hz] 127 | heatmap[ep][hz][3][4][0] = pxx[ep][10][hz] 128 | 129 | heatmap[ep][hz][4][1][0] = pxx[ep][18][hz] 130 | heatmap[ep][hz][4][2][0] = pxx[ep][15][hz] 131 | heatmap[ep][hz][4][3][0] = pxx[ep][1][hz] 132 | return heatmap 133 | 134 | # Add context to the origin data and label 135 | def AddContext(x, add_context=False): 136 | ''' 137 | Input: 138 | x: A tensor whose first axis is number of sample -> (samples, channel, length) 139 | Output: 140 | (n_sample, 3, n_channels, n_times) 141 | ''' 142 | if add_context: 143 | samples, channel, length = x.shape 144 | x = x[:, np.newaxis, :, :] 145 | ContextData = [] 146 | for cur_epoch in range(1, x.shape[0] - 1): 147 | cur_epoch_data = x[cur_epoch] 148 | former_epoch = x[cur_epoch - 1] 149 | latter_epoch = x[cur_epoch + 1] 150 | 151 | temporal_epoch = np.concatenate([former_epoch, cur_epoch_data, latter_epoch], axis=0) 152 | ContextData.append(temporal_epoch) 153 | ContextData = np.array(ContextData).swapaxes(1, 2) 154 | ContextData = ContextData.reshape([samples - 2, channel, 3 * length]) 155 | else: 156 | x = x[1: x.shape[0] - 1] 157 | ContextData = x 158 | return ContextData -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import model as SleepPrintNet 4 | 5 | import time 6 | import argparse 7 | import keras 8 | from keras import callbacks 9 | from keras import metrics 10 | from 
Utils.two_stream_dataloader import * 11 | from sklearn.metrics import confusion_matrix 12 | import matplotlib.pyplot as plt 13 | import tensorflow as tf 14 | from keras.utils import multi_gpu_model 15 | 16 | gpunums = len(os.environ["CUDA_VISIBLE_DEVICES"].split(',')) 17 | config = tf.ConfigProto() 18 | config.gpu_options.allow_growth = True # allocate dynamically 19 | config.gpu_options.per_process_gpu_memory_fraction = 0.9 20 | keras.backend.tensorflow_backend.set_session(tf.Session(config=config)) 21 | 22 | # Global vars 23 | fold_string = None 24 | best_acc = 0 25 | model = None 26 | X_test_6 = None 27 | y_test_6 = None 28 | X_test_16 = None 29 | y_test_16 = None 30 | X_test_psd = None 31 | X_test_eog = None 32 | y_test_psd = None 33 | file_dir = None 34 | filename = None 35 | model_dir = None 36 | real_fold = 1 37 | n_oversampling = 0 38 | over_cm = [] 39 | 40 | 41 | class LossHistory(callbacks.Callback): 42 | def on_train_begin(self, logs={}): 43 | self.losses = {'batch': [], 'epoch': []} 44 | self.accuracy = {'batch': [], 'epoch': []} 45 | self.val_loss = {'batch': [], 'epoch': []} 46 | self.val_acc = {'batch': [], 'epoch': []} 47 | 48 | def on_epoch_end(self, epoch, logs={}): 49 | 50 | self.losses['epoch'].append(logs.get('loss')) 51 | self.accuracy['epoch'].append(logs.get('acc')) 52 | self.val_loss['epoch'].append(logs.get('val_loss')) 53 | self.val_acc['epoch'].append(logs.get('val_acc')) 54 | self.draw_p2in1( 55 | self.losses['epoch'], self.val_loss['epoch'], 'loss', 'train_epoch', 'val_epoch') 56 | self.draw_p2in1( 57 | self.accuracy['epoch'], self.val_acc['epoch'], 'acc', 'train_epoch', 'val_epoch') 58 | global best_acc, over_cm 59 | 60 | # Output best 61 | if best_acc < max(self.val_acc['epoch']): 62 | best_acc = max(self.val_acc['epoch']) 63 | predict_test = np.argmax(model.predict( 64 | [X_test_psd, X_test_16, X_test_6, X_test_eog]), axis=1) 65 | cm = confusion_matrix(y_test_6, predict_test, 66 | labels=[0, 1, 2, 3, 4]) 67 | print(cm) 68 | over_cm[-1] = cm 69 | np.savetxt(filename+'.txt', cm, "%d") 70 | f = open(filename + '_best_acc.txt', "w") 71 | print(best_acc, file=f) 72 | f.close() 73 | print("acc", best_acc) 74 | 75 | def draw_p2in1(self, lists1, lists2, label, type1, type2): 76 | plt.figure() 77 | plt.plot(range(len(lists1)), lists1, 'r', label=type1) 78 | plt.plot(range(len(lists2)), lists2, 'b', label=type2) 79 | plt.ylabel(label) 80 | plt.xlabel(type1.split('_')[0]+'_'+type2.split('_')[0]) 81 | plt.legend(loc="upper right") 82 | global filename 83 | filename = file_dir+label+'_fold'+fold_string 84 | plt.savefig(filename+'.jpg') 85 | plt.close() 86 | 87 | def draw_p(self, lists, label, type): 88 | plt.figure() 89 | plt.plot(range(len(lists)), lists, 'r', label=label) 90 | plt.ylabel(label) 91 | plt.xlabel(type) 92 | plt.legend(loc="upper right") 93 | plt.savefig(filename+'.jpg') 94 | plt.close() 95 | 96 | def end_draw(self): 97 | self.draw_p2in1( 98 | self.losses['epoch'], self.val_loss['epoch'], 'loss', 'train_epoch', 'val_epoch') 99 | self.draw_p2in1( 100 | self.accuracy['epoch'], self.val_acc['epoch'], 'acc', 'train_epoch', 'val_epoch') 101 | 102 | 103 | def train(args): 104 | global fold_string, best_acc, model, file_dir, over_cm, X_test_16, X_test_6, y_test_16, y_test_6, X_test_psd, X_test_eog, y_test_psd 105 | 106 | classes = [0, 1, 2, 3, 4] 107 | num_classes = len(classes) 108 | num_folds = args.num_fold 109 | data_dir = args.data_dir1 110 | data_dir2 = args.data_dir2 111 | data_dir3 = args.data_dir3 112 | n_files = args.n_files 113 | model_dir = 
args.model_dir 114 | file_dir = args.result_dir 115 | 116 | seq_len = args.seqLen 117 | width = args.height 118 | height = args.width 119 | save_model = True if args.save_model else False 120 | 121 | all_time = 0 122 | acc = np.zeros(real_fold) 123 | 124 | if not os.path.exists(file_dir): 125 | os.mkdir(file_dir) 126 | if not os.path.exists(model_dir): 127 | os.mkdir(model_dir) 128 | 129 | for fold_idx in range(real_fold): 130 | over_cm.append([]) 131 | best_acc = 0 132 | fold_string = str(fold_idx) 133 | start_time_fold_i = time.time() 134 | logs_loss = LossHistory() 135 | print('train start time of fold{} is {}'.format( 136 | fold_idx, start_time_fold_i)) 137 | 138 | # Reading Data 139 | data_loader_16 = SeqDataLoader( 140 | data_dir, num_folds, fold_idx, classes, n_files) 141 | X_train_16, y_train_16, X_test_16, y_test_16 = data_loader_16.load_data() 142 | 143 | X_train_eog = X_train_16[:, 6:, :] 144 | X_train_16 = X_train_16[:, :6, :] 145 | 146 | X_test_eog = X_test_16[:, 6:, :] 147 | X_test_16 = X_test_16[:, :6, :] 148 | 149 | data_loader_6 = SeqDataLoader( 150 | data_dir2, num_folds, fold_idx, classes, n_files) 151 | X_train_6, y_train_6, X_test_6, y_test_6 = data_loader_6.load_data() 152 | 153 | data_loader_psd = SeqDataLoader( 154 | data_dir3, num_folds, fold_idx, classes, n_files) 155 | X_train_psd, y_train_psd, X_test_psd, y_test_psd = data_loader_psd.load_data() 156 | 157 | model_name = "model_fold{:02d}_in{:02d}of{:02d}.h5".format( 158 | fold_idx, num_folds, n_files) 159 | model = SleepPrintNet.create_SleepPrintNet( 160 | num_classes, seq_len, width, height, psd_filter_nums=args.num_filters, times=11520, Fs=128) 161 | 162 | if gpunums > 1: 163 | parallel_model = multi_gpu_model(model, gpus=gpunums) 164 | adam = keras.optimizers.Adam( 165 | lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8) 166 | parallel_model.compile( 167 | optimizer=adam, loss='sparse_categorical_crossentropy', metrics=['acc']) 168 | else: 169 | adam = keras.optimizers.Adam( 170 | lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8) 171 | model.compile( 172 | optimizer=adam, loss='sparse_categorical_crossentropy', metrics=['acc']) 173 | 174 | callbacks_list = [ 175 | callbacks.EarlyStopping( 176 | monitor='val_loss', 177 | patience=20 178 | ), 179 | logs_loss 180 | ] 181 | 182 | if save_model: 183 | callbacks_list.append(callbacks.ModelCheckpoint( 184 | filepath=model_dir + model_name, 185 | monitor='val_acc', 186 | save_best_only=True, 187 | )) 188 | 189 | if gpunums > 1: 190 | parallel_model.fit([X_train_psd, X_train_16, X_train_6, X_train_eog], y_train_6, validation_data=( 191 | [X_test_psd, X_test_16, X_test_6, X_test_eog], y_test_6), epochs=args.epoch, batch_size=args.batch_size, callbacks=callbacks_list, verbose=2, shuffle=True) 192 | else: 193 | model.fit([X_train_psd, X_train_16, X_train_6, X_train_eog], y_train_6, validation_data=([X_test_psd, X_test_16, X_test_6, 194 | X_test_eog], y_test_6), epochs=args.epoch, batch_size=args.batch_size, callbacks=callbacks_list, verbose=2, shuffle=True) 195 | 196 | del X_train_16, y_train_16, X_test_16, y_test_16, X_train_6, y_train_6, X_test_6, y_test_6, model, data_loader_16, data_loader_6, X_train_psd, y_train_psd, X_test_psd, y_test_psd, X_train_eog, X_test_eog, data_loader_psd 197 | 198 | end_time_fold_i = time.time() 199 | train_time_fold_i = end_time_fold_i - start_time_fold_i 200 | all_time += train_time_fold_i 201 | logs_loss.end_draw() 202 | acc[fold_idx] = max(logs_loss.val_acc['epoch']) 203 | print('train time of fold{} is {}'.format(fold_idx, 
train_time_fold_i)) 204 | 205 | for index in range(1, len(over_cm)): 206 | over_cm[0] += over_cm[index] 207 | print('train_time:', all_time) 208 | print("over_cm:") 209 | print(over_cm[0]) 210 | np.savetxt(file_dir+"over_cm.txt", acc) 211 | 212 | 213 | def main(): 214 | parser = argparse.ArgumentParser( 215 | description='SleepPrintNet - MASS-SS3 - K fold') 216 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 217 | help='batch size (default: 64)') 218 | parser.add_argument('--epoch', type=int, default=100, metavar='N', 219 | help='epoch (default: 100)') 220 | parser.add_argument('--num_fold', type=int, default=31, metavar='N', 221 | help='fold num (default:31)') 222 | 223 | parser.add_argument('--seqLen', type=int, default=5, metavar='N', 224 | help='Seq length (default: 5)') 225 | parser.add_argument('--height', type=int, default=16, metavar='N', 226 | help='Height of 2D Map (default: 16)') 227 | parser.add_argument('--width', type=int, default=16, metavar='N', 228 | help='Width of 2D Map (default: 16)') 229 | parser.add_argument('--num_filters', type=int, default=16, metavar='N', 230 | help='num_filters (default: 16)') 231 | parser.add_argument('--save_model', type=int, default=1, metavar='N', 232 | help='save_model (default: 0)') 233 | 234 | parser.add_argument('--model_dir', type=str, default='./output_model/', metavar='N', 235 | help='output dir (default: ./output_model/)') 236 | parser.add_argument('--data_dir1', type=str, default='./EEG_EOG', metavar='N', 237 | help='data_dir1 (default: ./EEG_EOG)') 238 | parser.add_argument('--data_dir2', type=str, default='./EMG', metavar='N', 239 | help='data_dir2 (default: ./EMG)') 240 | parser.add_argument('--data_dir3', type=str, default='./fre_spa', metavar='N', 241 | help='data_dir3 (default:./fre_spa)') 242 | parser.add_argument('--n_files', type=int, default=62, metavar='N', 243 | help='n_files (default: 62)') 244 | parser.add_argument('--result_dir', type=str, default='./result/', metavar='N', 245 | help='result_dir (default: ./result/)') 246 | 247 | args = parser.parse_args() 248 | 249 | print("SleepPrintNet") 250 | train(args) 251 | 252 | 253 | if __name__ == "__main__": 254 | main() 255 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.regularizers import l2 3 | import keras 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | K.set_image_data_format("channels_last") 7 | 8 | 9 | def se_slice_psd(x): 10 | return x[:, :, :, :, 0] 11 | 12 | 13 | def deepsleepnet(intput, Fs, time_filters_nums, bn_mom, version): 14 | ######### CNNs with small filter size at the first layer ######### 15 | y1 = keras.layers.Conv1D(name='conv1_small{}'.format(version), kernel_size=Fs//2, strides=Fs//16, filters=time_filters_nums, padding='same', 16 | use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(intput) 17 | y1 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 18 | beta_initializer='zeros', gamma_initializer='ones', 19 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 20 | beta_regularizer=None, gamma_regularizer=None, 21 | beta_constraint=None, gamma_constraint=None)(y1) 22 | y1 = keras.layers.LeakyReLU()(y1) 23 | 24 | y1 = keras.layers.MaxPooling1D(pool_size=8, strides=8, padding='same')(y1) 25 | y1 = keras.layers.Dropout(0.5)(y1) 26 | 27 | y1 = 
keras.layers.Conv1D(name='conv2_small{}'.format(version), kernel_size=8, strides=1, filters=time_filters_nums*2, padding='same', 28 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y1) 29 | y1 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 30 | beta_initializer='zeros', gamma_initializer='ones', 31 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 32 | beta_regularizer=None, gamma_regularizer=None, 33 | beta_constraint=None, gamma_constraint=None)(y1) 34 | y1 = keras.layers.LeakyReLU()(y1) 35 | 36 | y1 = keras.layers.Conv1D(name='conv3_small{}'.format(version), kernel_size=8, strides=1, filters=time_filters_nums*2, padding='same', 37 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y1) 38 | y1 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 39 | beta_initializer='zeros', gamma_initializer='ones', 40 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 41 | beta_regularizer=None, gamma_regularizer=None, 42 | beta_constraint=None, gamma_constraint=None)(y1) 43 | y1 = keras.layers.LeakyReLU()(y1) 44 | 45 | y1 = keras.layers.Conv1D(name='conv4_small{}'.format(version), kernel_size=8, strides=1, filters=time_filters_nums*2, padding='same', 46 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y1) 47 | y1 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 48 | beta_initializer='zeros', gamma_initializer='ones', 49 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 50 | beta_regularizer=None, gamma_regularizer=None, 51 | beta_constraint=None, gamma_constraint=None)(y1) 52 | y1 = keras.layers.LeakyReLU()(y1) 53 | y1 = keras.layers.MaxPooling1D(pool_size=4, strides=4, padding='same')(y1) 54 | y1 = keras.layers.Flatten()(y1) 55 | 56 | ######### CNNs with big filter size at the first layer ######### 57 | y2 = keras.layers.Conv1D(name='conv1_big{}'.format(version), kernel_size=Fs*4, strides=Fs//2, filters=time_filters_nums, padding='same', 58 | use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(intput) 59 | y2 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 60 | beta_initializer='zeros', gamma_initializer='ones', 61 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 62 | beta_regularizer=None, gamma_regularizer=None, 63 | beta_constraint=None, gamma_constraint=None)(y2) 64 | y2 = keras.layers.LeakyReLU()(y2) 65 | 66 | y2 = keras.layers.MaxPooling1D(pool_size=4, strides=4, padding='same')(y2) 67 | y2 = keras.layers.Dropout(0.5)(y2) 68 | 69 | y2 = keras.layers.Conv1D(name='conv2_big{}'.format(version), kernel_size=6, strides=1, filters=time_filters_nums*2, padding='same', 70 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y2) 71 | y2 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 72 | beta_initializer='zeros', gamma_initializer='ones', 73 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 74 | beta_regularizer=None, gamma_regularizer=None, 75 | beta_constraint=None, gamma_constraint=None)(y2) 76 | y2 = keras.layers.LeakyReLU()(y2) 77 | 78 | y2 = keras.layers.Conv1D(name='conv3_big{}'.format(version), kernel_size=6, strides=1, filters=time_filters_nums*2, padding='same', 79 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y2) 80 | y2 = 
keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 81 | beta_initializer='zeros', gamma_initializer='ones', 82 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 83 | beta_regularizer=None, gamma_regularizer=None, 84 | beta_constraint=None, gamma_constraint=None)(y2) 85 | y2 = keras.layers.LeakyReLU()(y2) 86 | 87 | y2 = keras.layers.Conv1D(name='conv4_big{}'.format(version), kernel_size=6, strides=1, filters=time_filters_nums*2, padding='same', 88 | kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(y2) 89 | y2 = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 90 | beta_initializer='zeros', gamma_initializer='ones', 91 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 92 | beta_regularizer=None, gamma_regularizer=None, 93 | beta_constraint=None, gamma_constraint=None)(y2) 94 | y2 = keras.layers.LeakyReLU()(y2) 95 | 96 | y2 = keras.layers.MaxPooling1D(pool_size=2, strides=2, padding='same')(y2) 97 | y2 = keras.layers.Flatten()(y2) 98 | y = keras.layers.concatenate([y1, y2], axis=-1) 99 | return y 100 | 101 | 102 | def create_SleepPrintNet( 103 | num_class, 104 | seq_len=100, 105 | width=16, 106 | height=16, 107 | use_bias=True, 108 | bn_mom=0.9, 109 | times=7680, 110 | Fs=128, 111 | time_filters_nums=64, 112 | psd_filter_nums=32 113 | ): 114 | 115 | # Begin Layers 116 | # (Samples,5,16,16,1) 117 | input_layer = keras.layers.Input( 118 | name='input_layer_psd', shape=(seq_len, width, height, 1)) 119 | input_psd = keras.layers.Lambda( 120 | se_slice_psd)(input_layer) 121 | input_psd = keras.layers.core.Permute((2, 3, 1))(input_psd) 122 | 123 | # Residual 16*16 124 | x_psd = keras.layers.Conv2D(name='conv1_middle_psd', kernel_size=(1, 1), strides=(1, 1), filters=psd_filter_nums, padding='same', 125 | use_bias=use_bias, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(input_psd) 126 | x_psd = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 127 | beta_initializer='zeros', gamma_initializer='ones', 128 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 129 | beta_regularizer=None, gamma_regularizer=None, 130 | beta_constraint=None, gamma_constraint=None)(x_psd) 131 | x_psd = keras.layers.ReLU()(x_psd) 132 | 133 | x = keras.layers.Conv2D(kernel_size=(3, 3), strides=(1, 1), filters=psd_filter_nums*2, padding='same', 134 | use_bias=True, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(x_psd) 135 | x = keras.layers.BatchNormalization(axis=-1, momentum=bn_mom, epsilon=0.001, center=True, scale=True, 136 | beta_initializer='zeros', gamma_initializer='ones', 137 | moving_mean_initializer='zeros', moving_variance_initializer='ones', 138 | beta_regularizer=None, gamma_regularizer=None, 139 | beta_constraint=None, gamma_constraint=None)(x) 140 | 141 | x_psd = keras.layers.Conv2D(kernel_size=(1, 1), strides=(1, 1), filters=psd_filter_nums*2, padding='same', 142 | use_bias=True, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(x_psd) 143 | x = keras.layers.add([x_psd, x]) 144 | x = keras.layers.MaxPooling2D(pool_size=( 145 | 4, 4), strides=(2, 2), padding='valid')(x) 146 | x = keras.layers.ReLU()(x) 147 | x = keras.layers.Flatten()(x) 148 | x = keras.layers.Dropout(0.5)(x) 149 | 150 | # DeepSleepNet for EEG, EMG, and EOG 151 | input_layer_time = keras.layers.Input( 152 | name='input_layer_time', shape=(6, times)) 153 | layer_time = 
keras.layers.Reshape((times, 6))(input_layer_time) 154 | y1 = deepsleepnet(layer_time, Fs, time_filters_nums, bn_mom, '1') 155 | 156 | input_layer_emg = keras.layers.Input( 157 | name='input_layer_emg', shape=(3, times)) 158 | layer_emg = keras.layers.Reshape((times, 3))(input_layer_emg) 159 | y2 = deepsleepnet(layer_emg, Fs, time_filters_nums//2, bn_mom, '2') 160 | 161 | input_layer_eog = keras.layers.Input( 162 | name='input_layer_time_eog', shape=(2, times)) 163 | layer_eog = keras.layers.Reshape((times, 2))(input_layer_eog) 164 | y3 = deepsleepnet(layer_eog, Fs, time_filters_nums//2, bn_mom, '3') 165 | 166 | y = keras.layers.concatenate([x, y1, y2, y3], axis=-1) 167 | 168 | y = keras.layers.Dense(128, activation='relu', 169 | kernel_regularizer=l2(0.1))(y) 170 | y = keras.layers.Dense(num_class, activation='softmax', 171 | kernel_regularizer=l2(0.1))(y) 172 | model = keras.models.Model( 173 | [input_layer, input_layer_time, input_layer_emg, input_layer_eog], y) 174 | return model 175 | --------------------------------------------------------------------------------
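
A note on the epoch-context convention used by `AddContext` in `preprocess/Utils.py` and reflected in the README's input shapes: with `add_context=True`, each 30-s epoch is concatenated with its preceding and following epochs along the time axis, so `(N, C, L)` becomes `(N - 2, C, 3 * L)`; with `add_context=False` (used for the labels and the spectral-spatial tensors), only the first and last epochs are dropped so everything stays aligned. The snippet below is a minimal stand-alone sketch of that behaviour; `with_context` is a simplified illustration with toy array sizes, not the repo's function.

```python
import numpy as np

def with_context(x, add_context=True):
    """Simplified illustration of the AddContext convention."""
    if add_context:
        # prev | current | next, concatenated along the time axis
        out = np.concatenate([x[:-2], x[1:-1], x[2:]], axis=-1)
        return out.astype(np.float32)    # (N - 2, C, 3 * L)
    return x[1:-1]                       # drop first/last epoch to stay aligned

signals = np.random.randn(10, 8, 3840)   # 10 epochs, 8 channels, 30 s * 128 Hz
labels = np.arange(10)                   # one label per epoch

print(with_context(signals).shape)                    # (8, 8, 11520)
print(with_context(labels, add_context=False).shape)  # (8,)
```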
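For reference, a minimal sketch of how the preprocessed arrays and `create_SleepPrintNet` fit together, following the defaults in `train.py` (batch size 64, Adam with lr=1e-3, sparse categorical cross-entropy). The file paths and the `validation_split` here are illustrative assumptions only; `train.py` itself performs subject-wise k-fold splitting via `SeqDataLoader`.

```python
import numpy as np
import keras
import model as SleepPrintNet

# Preprocessed arrays for one subject (paths are assumptions; see run_preprocess.py).
X_psd = np.load('./fre_spa/Data1.npy')            # (N, 5, 16, 16, 1)
X_eeg_eog = np.load('./EEG_EOG/Data1.npy')        # (N, 8, 11520): 6 EEG + 2 EOG channels
X_emg = np.load('./EMG/Data1.npy')                # (N, 3, 11520)
y = np.load('./EEG_EOG/Label1.npy').astype(int)   # (N,)
X_eeg, X_eog = X_eeg_eog[:, :6, :], X_eeg_eog[:, 6:, :]

net = SleepPrintNet.create_SleepPrintNet(
    num_class=5,           # 5 sleep stages (classes 0-4)
    seq_len=5,             # 5 frequency bands in the spectral-spatial input
    width=16, height=16,   # 16 x 16 topographic maps
    times=11520,           # 30 s * 128 Hz * 3 context epochs
    Fs=128)

net.compile(optimizer=keras.optimizers.Adam(lr=1e-3),
            loss='sparse_categorical_crossentropy',
            metrics=['acc'])

# Input order expected by the model: [spectral-spatial, EEG, EMG, EOG].
net.fit([X_psd, X_eeg, X_emg, X_eog], y,
        batch_size=64, epochs=100, validation_split=0.1, shuffle=True)
```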