├── .gitignore ├── README.md ├── batch_train.sh ├── ccrrsleep ├── GHMCloss.py ├── __init__.py ├── cbam1d.py ├── data_loader.py ├── loss.py ├── model.py ├── nn.py ├── optimize.py ├── sleep_stage.py ├── trainer.py └── utils.py ├── dhedfreader.py ├── predict.py ├── prepare_physionet.py ├── summary.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCRRSleepNet # 2 | 3 | Code for the model in the paper [**CCRRSleepNet: A Hybrid Relational Inductive Biases Network for Automatic Sleep Stage Classification on Raw Single-Channel EEG by Neng Wenpeng, Lu Jun, and Xu Lei**](https://www.mdpi.com/2076-3425/11/4/456#authors). 4 | 5 | 6 | 7 | This work has been accepted for publication in [Brain Sciences](https://www.mdpi.com/2076-3425/11/4/456#authors). 8 | 9 | 10 | ## Prepare dataset ## 11 | We evaluated our CCRRSleepNet with [Sleep-EDF](https://physionet.org/pn4/sleep-edfx/) dataset. 12 | 13 | For the [Sleep-EDF](https://physionet.org/pn4/sleep-edfx/) dataset, you can run the following scripts to download SC subjects. 14 | 15 | cd data 16 | chmod +x download_physionet.sh 17 | ./download_physionet.sh 18 | 19 | Then run the following script to extract specified EEG channels and their corresponding sleep stages. 20 | 21 | python prepare_physionet.py --data_dir data --output_dir data/eeg_fpz_cz --select_ch 'EEG Fpz-Cz' 22 | python prepare_physionet.py --data_dir data --output_dir data/eeg_pz_oz --select_ch 'EEG Pz-Oz' 23 | 24 | 25 | ## Training a model ## 26 | Run this script to train a DeepSleepNet model for the first fold of the 20-fold cross-validation. 27 | 28 | python train.py --data_dir data/eeg_fpz_cz --output_dir output --n_folds 20 --fold_idx 0 --pretrain_epochs 100 --finetune_epochs 200 --resume False 29 | 30 | You need to train a CCRRSleepNet model for every fold (i.e., `fold_idx=0...19`) before you can evaluate the performance. You can use the following script to run batch training 31 | 32 | chmod +x batch_train.sh 33 | ./batch_train.sh data/eeg_fpz_cz/ output 20 0 19 0 34 | 35 | 36 | ## Scoring sleep stages ## 37 | Run this script to determine the sleep stages for the withheld subject for each cross-validation fold. 38 | 39 | python predict.py --data_dir data/eeg_fpz_cz --model_dir output --output_dir output 40 | 41 | The output will be stored in numpy files. 42 | 43 | 44 | ## Get a summary ## 45 | Run this script to show a summary of the performance of our CCRRSleepNet compared with the state-of-the-art hand-engineering approaches. The performance metrics are overall accuracy, per-class F1-score, and macro F1-score. 46 | 47 | python summary.py --data_dir output 48 | 49 | 50 | 51 | ## Acknowledgement ## 52 | We refer to part of the code of [DeepSleepNet](https://github.com/akaraspt/deepsleepnet/) -------------------------------------------------------------------------------- /batch_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data=$1 4 | output=$2 5 | nfolds=$3 6 | start=$4 7 | end=$5 8 | gpu=$6 9 | 10 | if [[ -n "$data" ]] && [[ -n "$start" ]] && [[ -n "$end" ]] && [[ -n "$gpu" ]]; then 11 | for i in $(eval echo {$start..$end}) 12 | do 13 | CUDA_VISIBLE_DEVICES=$gpu python train.py --data_dir=$data --output_dir=$output --n_folds=$nfolds --pretrain_epochs=80 --finetune_epochs=80 --fold_idx=$i --resume=False 14 | done 15 | else 16 | echo "argument error" 17 | fi 18 | 19 | -------------------------------------------------------------------------------- /ccrrsleep/GHMCloss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class GHMCLoss: 5 | def __init__(self, bins=10, momentum=0.75): 6 | self.bins = bins 7 | self.momentum = momentum 8 | self.edges_left, self.edges_right = self.get_edges(self.bins) # edges_left: [bins, 1, 1], edges_right: [bins, 1, 1] 9 | if momentum > 0: 10 | self.acc_sum = self.get_acc_sum(self.bins) # [bins] 11 | 12 | def get_edges(self, bins): 13 | edges_left = [float(x) / bins for x in range(bins)] 14 | edges_left = tf.constant(edges_left) # [bins] 15 | edges_left = tf.expand_dims(edges_left, -1) # [bins, 1] 16 | edges_left = tf.expand_dims(edges_left, -1) # [bins, 1, 1] 17 | 18 | edges_right = [float(x) / bins for x in range(1, bins + 1)] 19 | edges_right[-1] += 1e-6 20 | edges_right = tf.constant(edges_right) # [bins] 21 | edges_right = tf.expand_dims(edges_right, -1) # [bins, 1] 22 | edges_right = tf.expand_dims(edges_right, -1) # [bins, 1, 1] 23 | return edges_left, edges_right 24 | 25 | def get_acc_sum(self, bins): 26 | acc_sum = [0.0 for _ in range(bins)] 27 | return tf.Variable(acc_sum, trainable=False) 28 | 29 | def calc(self, input, target, mask=None, is_mask=False): 30 | """ Args: 31 | input [batch_num, class_num]: 32 | The direct prediction of classification fc layer. 33 | target [batch_num, class_num]: 34 | Binary target (0 or 1) for each sample each class. The value is -1 35 | when the sample is ignored. 36 | mask [batch_num, class_num] 37 | """ 38 | edges_left, edges_right = self.edges_left, self.edges_right 39 | mmt = self.momentum 40 | # gradient length 41 | self.g = tf.abs(tf.sigmoid(input) - target) # [batch_num, class_num] 42 | g = tf.expand_dims(self.g, axis=0) # [1, batch_num, class_num] 43 | g_greater_equal_edges_left = tf.greater_equal(g, edges_left)# [bins, batch_num, class_num] 44 | g_less_edges_right = tf.less(g, edges_right)# [bins, batch_num, class_num] 45 | zero_matrix = tf.cast(tf.zeros_like(g_greater_equal_edges_left), dtype=tf.float32) # [bins, batch_num, class_num] 46 | if is_mask: 47 | mask_greater_zero = tf.greater(mask, 0) 48 | inds = tf.cast(tf.logical_and(tf.logical_and(g_greater_equal_edges_left, g_less_edges_right), 49 | mask_greater_zero), dtype=tf.float32) # [bins, batch_num, class_num] 50 | tot = tf.maximum(tf.reduce_sum(tf.cast(mask_greater_zero, dtype=tf.float32)), 1.0) 51 | else: 52 | inds = tf.cast(tf.logical_and(g_greater_equal_edges_left, g_less_edges_right), 53 | dtype=tf.float32) # [bins, batch_num, class_num] 54 | input_shape = tf.shape(input) 55 | tot = tf.maximum(tf.cast(input_shape[0] * input_shape[1], dtype=tf.float32), 1.0) 56 | num_in_bin = tf.reduce_sum(inds, axis=[1, 2]) # [bins] 57 | num_in_bin_greater_zero = tf.greater(num_in_bin, 0) # [bins] 58 | num_valid_bin = tf.reduce_sum(tf.cast(num_in_bin_greater_zero, dtype=tf.float32)) 59 | 60 | # num_in_bin = num_in_bin + 1e-12 61 | if mmt > 0: 62 | update = tf.assign(self.acc_sum, tf.where(num_in_bin_greater_zero, mmt * self.acc_sum \ 63 | + (1 - mmt) * num_in_bin, self.acc_sum)) 64 | with tf.control_dependencies([update]): 65 | self.acc_sum_tmp = tf.identity(self.acc_sum, name='updated_accsum') 66 | acc_sum = tf.expand_dims(self.acc_sum_tmp, -1) # [bins, 1] 67 | acc_sum = tf.expand_dims(acc_sum, -1) # [bins, 1, 1] 68 | acc_sum = acc_sum + zero_matrix # [bins, batch_num, class_num] 69 | weights = tf.where(tf.equal(inds, 1), tot / acc_sum, zero_matrix) 70 | weights = tf.reduce_sum(weights, axis=0) 71 | else: 72 | num_in_bin = tf.expand_dims(num_in_bin, -1) # [bins, 1] 73 | num_in_bin = tf.expand_dims(num_in_bin, -1) # [bins, 1, 1] 74 | num_in_bin = num_in_bin + zero_matrix # [bins, batch_num, class_num] 75 | weights = tf.where(tf.equal(inds, 1), tot / num_in_bin, zero_matrix) 76 | weights = tf.reduce_sum(weights, axis=0) 77 | weights = weights / num_valid_bin 78 | loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=input) 79 | loss = tf.reduce_sum(loss * weights) / tot 80 | return loss 81 | 82 | 83 | if __name__ == '__main__': 84 | ghm = GHMCLoss(momentum=0.75) 85 | input_1 = tf.constant([[0.05, 0.25],[0.15, 0.65]], dtype=tf.float32) # 86 | target_1 = tf.constant([[1.0, 0.0], [0.0, 1.0]], dtype=tf.float32) 87 | 88 | input_2 = tf.constant([[0.75, 0.65], [0.85, 0.05]], dtype=tf.float32) 89 | target_2 = tf.constant([[1.0, 0.0], [0.0, 0.0]], dtype=tf.float32) 90 | with tf.Session() as sess: 91 | init = tf.initialize_all_variables() 92 | sess.run(init) 93 | loss = ghm.calc(input_1, target_1) 94 | print(sess.run([loss,ghm.g,ghm.acc_sum_tmp])) 95 | loss = ghm.calc(input_2, target_2) 96 | print(sess.run([loss,ghm.g,ghm.acc_sum_tmp])) 97 | loss = ghm.calc(input_2, target_2) 98 | print(sess.run([loss,ghm.g,ghm.acc_sum_tmp])) 99 | loss = ghm.calc(input_1, target_1) 100 | print(sess.run([loss,ghm.g,ghm.acc_sum_tmp])) 101 | loss = ghm.calc(input_1, target_1) 102 | print(sess.run([loss,ghm.g,ghm.acc_sum_tmp])) 103 | 104 | # loss = ghm.calc(input_1, target_1) 105 | # print(sess.run([loss,ghm.g])) 106 | # loss = ghm.calc(input_2, target_2) 107 | # print(sess.run([loss,ghm.g])) 108 | # loss = ghm.calc(input_2, target_2) 109 | # print(sess.run([loss,ghm.g])) 110 | # loss = ghm.calc(input_1, target_1) 111 | # print(sess.run([loss,ghm.g])) 112 | # loss = ghm.calc(input_1, target_1) 113 | # print(sess.run([loss,ghm.g])) 114 | -------------------------------------------------------------------------------- /ccrrsleep/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengwp/CCRRSleepNet/6c53934d2667a3959adc561c22df4d58691899d5/ccrrsleep/__init__.py -------------------------------------------------------------------------------- /ccrrsleep/cbam1d.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def mish(x): 9 | return x * tf.math.tanh(tf.nn.softplus(x)) 10 | 11 | def se_block(residual, name, ratio=8): 12 | """Contains the implementation of Squeeze-and-Excitation(SE) block. 13 | As described in https://arxiv.org/abs/1709.01507. 14 | """ 15 | 16 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer() 17 | bias_initializer = tf.constant_initializer(value=0.0) 18 | 19 | with tf.variable_scope(name): 20 | channel = residual.get_shape()[-1] 21 | # Global average pooling 22 | squeeze = tf.reduce_mean(residual, axis=[1], keepdims=True) 23 | assert squeeze.get_shape()[1:] == (1, 1, channel) 24 | excitation = tf.layers.dense(inputs=squeeze, 25 | units=channel // ratio, 26 | activation=tf.nn.relu, 27 | kernel_initializer=kernel_initializer, 28 | bias_initializer=bias_initializer, 29 | name='bottleneck_fc') 30 | assert excitation.get_shape()[1:] == (1, 1, channel // ratio) 31 | excitation = tf.layers.dense(inputs=excitation, 32 | units=channel, 33 | activation=tf.nn.sigmoid, 34 | kernel_initializer=kernel_initializer, 35 | bias_initializer=bias_initializer, 36 | name='recover_fc') 37 | assert excitation.get_shape()[1:] == (1, 1, channel) 38 | # top = tf.multiply(bottom, se, name='scale') 39 | scale = residual * excitation 40 | return scale 41 | 42 | 43 | def cbam_block(input_feature, name, ratio=8): 44 | """Contains the implementation of Convolutional Block Attention Module(CBAM) block. 45 | As described in https://arxiv.org/abs/1807.06521. 46 | """ 47 | 48 | with tf.variable_scope(name): 49 | attention_feature = channel_attention(input_feature, 'ch_at', ratio) 50 | attention_feature = spatial_attention(attention_feature, 'sp_at') 51 | # print("CBAM Hello") 52 | return attention_feature 53 | 54 | 55 | def channel_attention(input_feature, name, ratio=8): 56 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer() 57 | bias_initializer = tf.constant_initializer(value=0.0) 58 | 59 | with tf.variable_scope(name): 60 | channel = input_feature.get_shape()[-1] 61 | avg_pool = tf.reduce_mean(input_feature, axis=[1], keepdims=True) 62 | 63 | assert avg_pool.get_shape()[1:] == (1, 1, channel) 64 | avg_pool = tf.layers.dense(inputs=avg_pool, 65 | units=channel // ratio, 66 | activation=mish, 67 | kernel_initializer=kernel_initializer, 68 | bias_initializer=bias_initializer, 69 | name='mlp_0', 70 | reuse=None) 71 | assert avg_pool.get_shape()[1:] == (1, 1, channel // ratio) 72 | avg_pool = tf.layers.dense(inputs=avg_pool, 73 | units=channel, 74 | kernel_initializer=kernel_initializer, 75 | bias_initializer=bias_initializer, 76 | name='mlp_1', 77 | reuse=None) 78 | assert avg_pool.get_shape()[1:] == (1, 1, channel) 79 | 80 | max_pool = tf.reduce_max(input_feature, axis=[1], keepdims=True) 81 | assert max_pool.get_shape()[1:] == (1, 1, channel) 82 | max_pool = tf.layers.dense(inputs=max_pool, 83 | units=channel // ratio, 84 | activation=mish, 85 | name='mlp_0', 86 | reuse=True) 87 | assert max_pool.get_shape()[1:] == (1, 1, channel // ratio) 88 | max_pool = tf.layers.dense(inputs=max_pool, 89 | units=channel, 90 | name='mlp_1', 91 | reuse=True) 92 | assert max_pool.get_shape()[1:] == (1, 1, channel) 93 | 94 | scale = tf.sigmoid(avg_pool + max_pool, 'sigmoid') 95 | 96 | return input_feature * scale 97 | 98 | 99 | def spatial_attention(input_feature, name, kernel_size = 7): 100 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer() 101 | with tf.variable_scope(name): 102 | avg_pool = tf.reduce_mean(input_feature, axis=[3], keepdims=True) 103 | assert avg_pool.get_shape()[-1] == 1 104 | max_pool = tf.reduce_max(input_feature, axis=[3], keepdims=True) 105 | assert max_pool.get_shape()[-1] == 1 106 | concat = tf.concat([avg_pool, max_pool], 3) 107 | assert concat.get_shape()[-1] == 2 108 | 109 | concat = tf.layers.conv2d(concat, 110 | filters=1, 111 | kernel_size=[kernel_size, 1], 112 | strides=[1, 1], 113 | padding="same", 114 | activation=None, 115 | kernel_initializer=kernel_initializer, 116 | use_bias=False, 117 | name='conv') 118 | assert concat.get_shape()[-1] == 1 119 | concat = tf.sigmoid(concat, 'sigmoid') 120 | 121 | return input_feature * concat -------------------------------------------------------------------------------- /ccrrsleep/data_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/akaraspt/deepsleepnet 3 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 4 | ''' 5 | 6 | import os 7 | 8 | import numpy as np 9 | 10 | from ccrrsleep.sleep_stage import print_n_samples_each_class 11 | from ccrrsleep.utils import get_balance_class_oversample, sequence_down_sample 12 | 13 | import re 14 | from random import shuffle,seed 15 | 16 | SEED = 666 17 | use_val = False 18 | 19 | 20 | class NonSeqDataLoader(object): 21 | 22 | def __init__(self, data_dir, n_folds, fold_idx): 23 | self.data_dir = data_dir 24 | self.n_folds = n_folds 25 | self.fold_idx = fold_idx 26 | 27 | def _load_npz_file(self, npz_file): 28 | """Load data and labels from a npz file.""" 29 | with np.load(npz_file) as f: 30 | data = f["x"] 31 | labels = f["y"] 32 | sampling_rate = f["fs"] 33 | return data, labels, sampling_rate 34 | 35 | def _load_npz_list_files(self, npz_files): 36 | """Load data and labels from list of npz files.""" 37 | data = [] 38 | labels = [] 39 | fs = None 40 | for npz_f in npz_files: 41 | print("Loading {} ...".format(npz_f)) 42 | tmp_data, tmp_labels, sampling_rate = self._load_npz_file(npz_f) 43 | if fs is None: 44 | fs = sampling_rate 45 | elif fs != sampling_rate: 46 | raise Exception("Found mismatch in sampling rate.") 47 | data.append(tmp_data) 48 | labels.append(tmp_labels) 49 | data = np.vstack(data) 50 | labels = np.hstack(labels) 51 | return data, labels 52 | 53 | def load_train_data(self): 54 | allfiles = os.listdir(self.data_dir) 55 | npzfiles = [] 56 | for idx, f in enumerate(allfiles): 57 | if ".npz" in f: 58 | npzfiles.append(os.path.join(self.data_dir, f)) 59 | npzfiles.sort() 60 | name_list = [] 61 | for idx, file in enumerate(npzfiles): 62 | name_i = [] 63 | name = file[-9:-7] 64 | for idx, file in enumerate(npzfiles): 65 | if name == file[-9:-7]: 66 | name_i.append(file) 67 | if name_i not in name_list: 68 | name_list.append(name_i) 69 | list_index = list(range(len(name_list))) 70 | # Split files for training and validation sets test set 71 | test_index = np.array_split(list_index, self.n_folds) 72 | test_index = test_index[self.fold_idx] 73 | res_index = np.setdiff1d(list_index, test_index) 74 | if use_val: 75 | res_index = np.array_split(res_index, len(res_index)) 76 | res_index.sort() 77 | seed(SEED) 78 | shuffle(res_index) 79 | val_index = np.array(res_index[0:int(len(res_index) * 0.3)]).reshape(-1) 80 | train_index = np.setdiff1d(res_index, val_index) 81 | 82 | train_files = sum([name_list[i] for i in train_index], []) 83 | val_files = sum([name_list[i] for i in val_index], []) 84 | test_files = sum([name_list[i] for i in test_index], []) 85 | else: 86 | train_files = sum([name_list[i] for i in res_index], []) 87 | val_files = sum([name_list[i] for i in test_index], []) 88 | test_files = sum([name_list[i] for i in test_index], []) 89 | print('交叉验证 总数 ', self.n_folds) 90 | print('当前训练记录总数',len(train_files)) 91 | print('当前验证记录总数',len(val_files)) 92 | print('当前测试记录总数',len(test_files)) 93 | # Load a npz file 94 | print("Load training set:") 95 | data_train, label_train = self._load_npz_list_files(train_files) 96 | print(" ") 97 | 98 | print("Load validation set:") 99 | data_val, label_val = self._load_npz_list_files(val_files) 100 | print(" ") 101 | 102 | # Reshape the data to match the input of the model - conv2d 103 | data_train = np.squeeze(data_train) 104 | data_val = np.squeeze(data_val) 105 | data_train = data_train[:, :, np.newaxis, np.newaxis] 106 | data_val = data_val[:, :, np.newaxis, np.newaxis] 107 | 108 | # Casting 109 | data_train = data_train.astype(np.float32) 110 | label_train = label_train.astype(np.int32) 111 | data_val = data_val.astype(np.float32) 112 | label_val = label_val.astype(np.int32) 113 | 114 | 115 | # Use balanced-class, oversample training set 116 | # x_train, y_train = get_balance_class_oversample( 117 | # x=data_train, y=label_train 118 | # ) 119 | x_train, y_train =data_train,label_train 120 | 121 | print("Oversampled training set: {}, {}".format( 122 | x_train.shape, y_train.shape 123 | )) 124 | print_n_samples_each_class(y_train) 125 | print(" ") 126 | 127 | return x_train, y_train, data_val, label_val 128 | 129 | def load_test_data(self): 130 | allfiles = os.listdir(self.data_dir) 131 | npzfiles = [] 132 | for idx, f in enumerate(allfiles): 133 | if ".npz" in f: 134 | npzfiles.append(os.path.join(self.data_dir, f)) 135 | npzfiles.sort() 136 | name_list = [] 137 | for idx, file in enumerate(npzfiles): 138 | name_i = [] 139 | name = file[-9:-7] 140 | for idx, file in enumerate(npzfiles): 141 | if name == file[-9:-7]: 142 | name_i.append(file) 143 | if name_i not in name_list: 144 | name_list.append(name_i) 145 | list_index = list(range(len(name_list))) 146 | # Split files for training and validation sets test set 147 | test_index = np.array_split(list_index, self.n_folds) 148 | test_index = test_index[self.fold_idx] 149 | 150 | res_index = np.setdiff1d(list_index, test_index) 151 | if use_val: 152 | res_index = np.array_split(res_index, len(res_index)) 153 | res_index.sort() 154 | seed(SEED) 155 | shuffle(res_index) 156 | val_index = np.array(res_index[0:int(len(res_index) * 0.3)]).reshape(-1) 157 | train_index = np.setdiff1d(res_index, val_index) 158 | 159 | train_files = sum([name_list[i] for i in train_index], []) 160 | val_files = sum([name_list[i] for i in val_index], []) 161 | test_files = sum([name_list[i] for i in test_index], []) 162 | else: 163 | train_files = sum([name_list[i] for i in res_index], []) 164 | val_files = sum([name_list[i] for i in test_index], []) 165 | test_files = sum([name_list[i] for i in test_index], []) 166 | print('交叉验证 总数 ', self.n_folds) 167 | print('当前训练记录总数',len(train_files)) 168 | print('当前验证记录总数',len(val_files)) 169 | print('当前测试记录总数',len(test_files)) 170 | print("Load test set:") 171 | data_test, label_test = self._load_npz_list_files(test_files) 172 | print(" ") 173 | # Reshape the data to match the input of the model 174 | data_test = np.squeeze(data_test) 175 | data_test = data_test[:, :, np.newaxis, np.newaxis] 176 | 177 | # Casting 178 | data_test = data_test.astype(np.float32) 179 | label_test = label_test.astype(np.int32) 180 | 181 | return data_test, label_test 182 | 183 | 184 | class SeqDataLoader(object): 185 | 186 | def __init__(self, data_dir, n_folds, fold_idx, sequence_length = 25): 187 | self.data_dir = data_dir 188 | self.n_folds = n_folds 189 | self.fold_idx = fold_idx 190 | self.sequence_length = sequence_length 191 | 192 | def _load_npz_file(self, npz_file): 193 | """Load data and labels from a npz file.""" 194 | with np.load(npz_file) as f: 195 | data = f["x"] 196 | labels = f["y"] 197 | sampling_rate = f["fs"] 198 | return data, labels, sampling_rate 199 | 200 | def _load_npz_list_files(self, npz_files): 201 | """Load data and labels from list of npz files.""" 202 | data = [] 203 | labels = [] 204 | fs = None 205 | for npz_f in npz_files: 206 | print("Loading {} ...".format(npz_f)) 207 | tmp_data, tmp_labels, sampling_rate = self._load_npz_file(npz_f) 208 | if fs is None: 209 | fs = sampling_rate 210 | elif fs != sampling_rate: 211 | raise Exception("Found mismatch in sampling rate.") 212 | 213 | # Reshape the data to match the input of the model - conv2d 214 | tmp_data = np.squeeze(tmp_data) 215 | tmp_data = tmp_data[:, :, np.newaxis, np.newaxis] 216 | 217 | # # Reshape the data to match the input of the model - conv1d 218 | # tmp_data = tmp_data[:, :, np.newaxis] 219 | 220 | # Casting 221 | tmp_data = tmp_data.astype(np.float32) 222 | tmp_labels = tmp_labels.astype(np.int32) 223 | 224 | data.append(tmp_data) 225 | labels.append(tmp_labels) 226 | 227 | return data, labels 228 | 229 | def _load_cv_data(self, list_files): 230 | """Load sequence training and cross-validation sets.""" 231 | # Split files for training and validation sets 232 | val_files = np.array_split(list_files, self.n_folds) 233 | train_files = np.setdiff1d(list_files, val_files[self.fold_idx]) 234 | 235 | # Load a npz file 236 | print("Load training set:") 237 | data_train, label_train = self._load_npz_list_files(train_files) 238 | print(" ") 239 | print("Load validation set:") 240 | data_val, label_val = self._load_npz_list_files(val_files[self.fold_idx]) 241 | print(" ") 242 | 243 | return data_train, label_train, data_val, label_val 244 | 245 | def load_test_data(self): 246 | allfiles = os.listdir(self.data_dir) 247 | npzfiles = [] 248 | for idx, f in enumerate(allfiles): 249 | if ".npz" in f: 250 | npzfiles.append(os.path.join(self.data_dir, f)) 251 | npzfiles.sort() 252 | name_list = [] 253 | for idx, file in enumerate(npzfiles): 254 | name_i = [] 255 | name = file[-9:-7] 256 | for idx, file in enumerate(npzfiles): 257 | if name == file[-9:-7]: 258 | name_i.append(file) 259 | if name_i not in name_list: 260 | name_list.append(name_i) 261 | list_index = list(range(len(name_list))) 262 | # Split files for training and validation sets test set 263 | test_index = np.array_split(list_index, self.n_folds) 264 | test_index = test_index[self.fold_idx] 265 | 266 | res_index = np.setdiff1d(list_index, test_index) 267 | if use_val: 268 | res_index = np.array_split(res_index, len(res_index)) 269 | res_index.sort() 270 | seed(SEED) 271 | shuffle(res_index) 272 | val_index = np.array(res_index[0:int(len(res_index) * 0.3)]).reshape(-1) 273 | train_index = np.setdiff1d(res_index, val_index) 274 | 275 | train_files = sum([name_list[i] for i in train_index], []) 276 | val_files = sum([name_list[i] for i in val_index], []) 277 | test_files = sum([name_list[i] for i in test_index], []) 278 | else: 279 | train_files = sum([name_list[i] for i in res_index], []) 280 | val_files = sum([name_list[i] for i in test_index], []) 281 | test_files = sum([name_list[i] for i in test_index], []) 282 | 283 | print('交叉验证 总数 ', self.n_folds) 284 | print('当前训练记录总数',len(train_files)) 285 | print('当前验证记录总数',len(val_files)) 286 | print('当前测试记录总数',len(test_files)) 287 | print("Load test set:") 288 | data_test, label_test = self._load_npz_list_files(test_files) 289 | 290 | return data_test, label_test 291 | 292 | def load_train_data(self, n_files=None): 293 | allfiles = os.listdir(self.data_dir) 294 | npzfiles = [] 295 | for idx, f in enumerate(allfiles): 296 | if ".npz" in f: 297 | npzfiles.append(os.path.join(self.data_dir, f)) 298 | npzfiles.sort() 299 | name_list = [] 300 | for idx, file in enumerate(npzfiles): 301 | name_i = [] 302 | name = file[-9:-7] 303 | for idx, file in enumerate(npzfiles): 304 | if name == file[-9:-7]: 305 | name_i.append(file) 306 | if name_i not in name_list: 307 | name_list.append(name_i) 308 | list_index = list(range(len(name_list))) 309 | # Split files for training and validation sets test set 310 | test_index = np.array_split(list_index, self.n_folds) 311 | test_index = test_index[self.fold_idx] 312 | 313 | res_index = np.setdiff1d(list_index, test_index) 314 | if use_val: 315 | res_index = np.array_split(res_index, len(res_index)) 316 | res_index.sort() 317 | seed(SEED) 318 | shuffle(res_index) 319 | val_index = np.array(res_index[0:int(len(res_index) * 0.3)]).reshape(-1) 320 | train_index = np.setdiff1d(res_index, val_index) 321 | 322 | train_files = sum([name_list[i] for i in train_index], []) 323 | val_files = sum([name_list[i] for i in val_index], []) 324 | test_files = sum([name_list[i] for i in test_index], []) 325 | else: 326 | train_files = sum([name_list[i] for i in res_index], []) 327 | val_files = sum([name_list[i] for i in test_index], []) 328 | test_files = sum([name_list[i] for i in test_index], []) 329 | print('交叉验证折总数 ', self.n_folds) 330 | print('当前训练记录总数',len(train_files)) 331 | print('当前验证记录总数',len(val_files)) 332 | print('当前测试记录总数',len(test_files)) 333 | 334 | # Load training and validation sets 335 | print("\n========== [Fold-{}] ==========\n".format(self.fold_idx)) 336 | print("Load training set:") 337 | data_train, label_train = self._load_npz_list_files(train_files) 338 | 339 | print(" ") 340 | print("Load validation set:") 341 | data_val, label_val = self._load_npz_list_files(val_files) 342 | print(" ") 343 | 344 | print("Training set: n_subjects={}".format(len(data_train))) 345 | 346 | # -------------------------- 347 | for i in range(len(data_train)): 348 | # 二者取其一 349 | # 状态转化关系平横 350 | data_train[i], label_train[i] = sequence_down_sample(data_train[i], label_train[i], sequence_length=self.sequence_length) 351 | # 完全随机上采样 352 | # data_train[i], label_train[i] = get_balance_class_oversample(data_train[i], label_train[i]) 353 | # -------------------------- 354 | 355 | n_train_examples = 0 356 | for d in data_train: 357 | print(d.shape) 358 | n_train_examples += d.shape[0] 359 | print("Number of examples = {}".format(n_train_examples)) 360 | print_n_samples_each_class(np.hstack(label_train)) 361 | print(" ") 362 | print("Validation set: n_subjects={}".format(len(data_val))) 363 | n_valid_examples = 0 364 | for d in data_val: 365 | print(d.shape) 366 | n_valid_examples += d.shape[0] 367 | print("Number of examples = {}".format(n_valid_examples)) 368 | print_n_samples_each_class(np.hstack(label_val)) 369 | print(" ") 370 | 371 | return data_train, label_train, data_val, label_val 372 | 373 | @staticmethod 374 | def load_subject_data(data_dir, subject_idx): 375 | # Remove non-mat files, and perform ascending sort 376 | allfiles = os.listdir(data_dir) 377 | subject_files = [] 378 | for idx, f in enumerate(allfiles): 379 | if subject_idx < 10: 380 | pattern = re.compile("[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(subject_idx)) 381 | else: 382 | pattern = re.compile("[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(subject_idx)) 383 | if pattern.match(f): 384 | subject_files.append(os.path.join(data_dir, f)) 385 | 386 | # Files for validation sets 387 | if len(subject_files) == 0 or len(subject_files) > 2: 388 | raise Exception("Invalid file pattern") 389 | 390 | def load_npz_file(npz_file): 391 | """Load data and labels from a npz file.""" 392 | with np.load(npz_file) as f: 393 | data = f["x"] 394 | labels = f["y"] 395 | sampling_rate = f["fs"] 396 | return data, labels, sampling_rate 397 | 398 | def load_npz_list_files(npz_files): 399 | """Load data and labels from list of npz files.""" 400 | data = [] 401 | labels = [] 402 | fs = None 403 | for npz_f in npz_files: 404 | print("Loading {} ...".format(npz_f)) 405 | tmp_data, tmp_labels, sampling_rate = load_npz_file(npz_f) 406 | if fs is None: 407 | fs = sampling_rate 408 | elif fs != sampling_rate: 409 | raise Exception("Found mismatch in sampling rate.") 410 | 411 | # Reshape the data to match the input of the model - conv2d 412 | tmp_data = np.squeeze(tmp_data) 413 | tmp_data = tmp_data[:, :, np.newaxis, np.newaxis] 414 | 415 | # # Reshape the data to match the input of the model - conv1d 416 | # tmp_data = tmp_data[:, :, np.newaxis] 417 | 418 | # Casting 419 | tmp_data = tmp_data.astype(np.float32) 420 | tmp_labels = tmp_labels.astype(np.int32) 421 | 422 | data.append(tmp_data) 423 | labels.append(tmp_labels) 424 | 425 | return data, labels 426 | 427 | print("Load data from: {}".format(subject_files)) 428 | data, labels = load_npz_list_files(subject_files) 429 | 430 | return data, labels 431 | -------------------------------------------------------------------------------- /ccrrsleep/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class GHMCLoss: 4 | def __init__(self, bins=10, momentum=0.75): 5 | self.bins = bins 6 | self.momentum = momentum 7 | self.edges_left, self.edges_right = self.get_edges(self.bins) # edges_left: [bins, 1, 1], edges_right: [bins, 1, 1] 8 | if momentum > 0: 9 | self.acc_sum = self.get_acc_sum(self.bins) # [bins] 10 | 11 | def get_edges(self, bins): 12 | edges_left = [float(x) / bins for x in range(bins)] 13 | edges_left = tf.constant(edges_left) # [bins] 14 | edges_left = tf.expand_dims(edges_left, -1) # [bins, 1] 15 | edges_left = tf.expand_dims(edges_left, -1) # [bins, 1, 1] 16 | 17 | edges_right = [float(x) / bins for x in range(1, bins + 1)] 18 | edges_right[-1] += 1e-6 19 | edges_right = tf.constant(edges_right) # [bins] 20 | edges_right = tf.expand_dims(edges_right, -1) # [bins, 1] 21 | edges_right = tf.expand_dims(edges_right, -1) # [bins, 1, 1] 22 | return edges_left, edges_right 23 | 24 | def get_acc_sum(self, bins): 25 | acc_sum = [0.0 for _ in range(bins)] 26 | return tf.Variable(acc_sum, trainable=False) 27 | 28 | def calc(self, target, input, mask=None, is_mask=False): 29 | """ Args: 30 | input [batch_num, class_num]: 31 | The direct prediction of classification fc layer. 32 | target [batch_num, class_num]: 33 | Binary target (0 or 1) for each sample each class. The value is -1 34 | when the sample is ignored. 35 | mask [batch_num, class_num] 36 | """ 37 | edges_left, edges_right = self.edges_left, self.edges_right 38 | mmt = self.momentum 39 | # gradient length 40 | self.g = tf.abs(tf.sigmoid(input) - target) # [batch_num, class_num] 41 | g = tf.expand_dims(self.g, axis=0) # [1, batch_num, class_num] 42 | g_greater_equal_edges_left = tf.greater_equal(g, edges_left)# [bins, batch_num, class_num] 43 | g_less_edges_right = tf.less(g, edges_right)# [bins, batch_num, class_num] 44 | zero_matrix = tf.cast(tf.zeros_like(g_greater_equal_edges_left), dtype=tf.float32) # [bins, batch_num, class_num] 45 | if is_mask: 46 | mask_greater_zero = tf.greater(mask, 0) 47 | inds = tf.cast(tf.logical_and(tf.logical_and(g_greater_equal_edges_left, g_less_edges_right), 48 | mask_greater_zero), dtype=tf.float32) # [bins, batch_num, class_num] 49 | tot = tf.maximum(tf.reduce_sum(tf.cast(mask_greater_zero, dtype=tf.float32)), 1.0) 50 | else: 51 | inds = tf.cast(tf.logical_and(g_greater_equal_edges_left, g_less_edges_right), 52 | dtype=tf.float32) # [bins, batch_num, class_num] 53 | input_shape = tf.shape(input) 54 | tot = tf.maximum(tf.cast(input_shape[0] * input_shape[1], dtype=tf.float32), 1.0) 55 | num_in_bin = tf.reduce_sum(inds, axis=[1, 2]) # [bins] 56 | num_in_bin_greater_zero = tf.greater(num_in_bin, 0) # [bins] 57 | num_valid_bin = tf.reduce_sum(tf.cast(num_in_bin_greater_zero, dtype=tf.float32)) 58 | 59 | # num_in_bin = num_in_bin + 1e-12 60 | if mmt > 0: 61 | update = tf.assign(self.acc_sum, tf.where(num_in_bin_greater_zero, mmt * self.acc_sum \ 62 | + (1 - mmt) * num_in_bin, self.acc_sum)) 63 | with tf.control_dependencies([update]): 64 | self.acc_sum_tmp = tf.identity(self.acc_sum, name='updated_accsum') 65 | acc_sum = tf.expand_dims(self.acc_sum_tmp, -1) # [bins, 1] 66 | acc_sum = tf.expand_dims(acc_sum, -1) # [bins, 1, 1] 67 | acc_sum = acc_sum + zero_matrix # [bins, batch_num, class_num] 68 | weights = tf.where(tf.equal(inds, 1), tot / acc_sum, zero_matrix) 69 | weights = tf.reduce_sum(weights, axis=0) 70 | else: 71 | num_in_bin = tf.expand_dims(num_in_bin, -1) # [bins, 1] 72 | num_in_bin = tf.expand_dims(num_in_bin, -1) # [bins, 1, 1] 73 | num_in_bin = num_in_bin + zero_matrix # [bins, batch_num, class_num] 74 | weights = tf.where(tf.equal(inds, 1), tot / num_in_bin, zero_matrix) 75 | weights = tf.reduce_sum(weights, axis=0) 76 | weights = weights / num_valid_bin 77 | loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=input) 78 | loss = tf.reduce_sum(loss * weights) / tot 79 | return loss 80 | 81 | 82 | def focal_loss(labels, logits, gamma=2.0, alpha=0.25): 83 | logits = tf.nn.softmax(logits, dim=-1) 84 | epsilon = 1.e-9 85 | labels = tf.to_int64(labels) 86 | labels = tf.convert_to_tensor(labels, tf.int64) 87 | logits = tf.convert_to_tensor(logits, tf.float32) 88 | num_cls = logits.shape[1] 89 | 90 | model_out = tf.add(logits, epsilon) 91 | onehot_labels = tf.one_hot(labels, num_cls) 92 | ce = tf.multiply(onehot_labels, -tf.log(model_out)) 93 | weight = tf.multiply(onehot_labels, tf.pow(tf.subtract(1., model_out), gamma)) 94 | fl = tf.multiply(alpha, tf.multiply(weight, ce)) 95 | # reduced_fl = tf.reduce_max(fl, axis=1) 96 | reduced_fl = tf.reduce_sum(fl, axis=1) # same as reduce_max 97 | return reduced_fl 98 | 99 | 100 | def ghmc_loss(labels, logits,momentum=0.75): 101 | logits = tf.nn.softmax(logits, dim=-1) 102 | labels = tf.to_int64(labels) 103 | labels = tf.convert_to_tensor(labels, tf.int64) 104 | logits = tf.convert_to_tensor(logits, tf.float32) 105 | num_cls = logits.shape[1] 106 | onehot_labels = tf.one_hot(labels, num_cls) 107 | ghm = GHMCLoss(momentum=momentum) 108 | return ghm.calc(onehot_labels, logits) 109 | -------------------------------------------------------------------------------- /ccrrsleep/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ccrrsleep.nn import * 4 | from ccrrsleep.loss import * 5 | 6 | from ccrrsleep.cbam1d import * 7 | 8 | class CCRRFeatureNet(object): 9 | 10 | def __init__( 11 | self, 12 | batch_size, 13 | input_dims, 14 | n_classes, 15 | is_train, 16 | reuse_params, 17 | use_dropout, 18 | name="ccrrfeaturenet" 19 | ): 20 | self.batch_size = batch_size 21 | self.input_dims = input_dims 22 | self.n_classes = n_classes 23 | self.is_train = is_train 24 | self.reuse_params = reuse_params 25 | self.use_dropout = use_dropout 26 | self.name = name 27 | 28 | self.activations = [] 29 | self.layer_idx = 1 30 | self.monitor_vars = [] 31 | 32 | def _build_placeholder(self): 33 | # Input 34 | name = "x_train" if self.is_train else "x_valid" 35 | self.input_var = tf.compat.v1.placeholder( 36 | tf.float32, 37 | shape=[self.batch_size, self.input_dims, 1, 1], 38 | name=name + "_inputs" 39 | ) 40 | # Target 41 | self.target_var = tf.compat.v1.placeholder( 42 | tf.int32, 43 | shape=[self.batch_size, ], 44 | name=name + "_targets" 45 | ) 46 | 47 | def _bn_activate_layer(self, input_var): 48 | name = "l{}_bn_activate".format(self.layer_idx) 49 | with tf.compat.v1.variable_scope(name) as scope: 50 | output = batch_norm_new(name="bn", input_var=input_var, is_train=self.is_train) 51 | # output = tf.nn.relu(output, name="relu") 52 | # output = tf.nn.swish(output, name="swish") 53 | output = mish(output) 54 | self.activations.append((name, output)) 55 | self.layer_idx += 1 56 | return output 57 | 58 | def _conv1d_layer(self, input_var, filter_size, n_filters, stride=1, dilations=1, padding="SAME", wd=0.0): 59 | input_shape = input_var.get_shape() 60 | n_batches = input_shape[0].value 61 | input_dims = input_shape[1].value 62 | n_in_filters = input_shape[3].value 63 | name = "l{}_conv".format(self.layer_idx) 64 | with tf.compat.v1.variable_scope(name) as scope: 65 | output = conv_1d(name="conv1d", input_var=input_var, filter_shape=[filter_size, 1, n_in_filters, n_filters], 66 | stride=stride, padding=padding, dilations=dilations, bias=None, wd=wd) 67 | output = self._bn_activate_layer(output) 68 | self.activations.append((name, output)) 69 | self.layer_idx += 1 70 | return output 71 | 72 | def _mcb_layer(self, input_var, n_filters=64): 73 | name = "l{}_mcb_conv".format(self.layer_idx) 74 | with tf.compat.v1.variable_scope(name) as scope: 75 | output_x = self._conv1d_layer(input_var=input_var, filter_size=1, n_filters=n_filters, dilations=1) 76 | 77 | name = "l{}_res_conv1".format(self.layer_idx) 78 | with tf.compat.v1.variable_scope(name) as scope: 79 | output = self._conv1d_layer(input_var=input_var, filter_size=1, n_filters=n_filters, dilations=1) 80 | output_res1 = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=1) 81 | 82 | name = "l{}_res_conv2".format(self.layer_idx) 83 | with tf.compat.v1.variable_scope(name) as scope: 84 | output = self._conv1d_layer(input_var=input_var, filter_size=1, n_filters=n_filters, dilations=1) 85 | output = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=1) 86 | output_res2 = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=3) 87 | 88 | name = "l{}_res_conv3".format(self.layer_idx) 89 | with tf.compat.v1.variable_scope(name) as scope: 90 | output = self._conv1d_layer(input_var=input_var, filter_size=1, n_filters=n_filters, dilations=1) 91 | output = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=1) 92 | output = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=1) 93 | output_res3 = self._conv1d_layer(input_var=output, filter_size=3, n_filters=n_filters, dilations=5) 94 | 95 | output = self._concat_layer([output_x, output_res1, output_res2, output_res3], axis=-1) 96 | output = self._bn_activate_layer(output) 97 | self.activations.append((name, output)) 98 | self.layer_idx += 1 99 | return output 100 | 101 | def _max_pool_1d_layer(self, input_var, pool_size=2,stride=None): 102 | name = "l{}_pool".format(self.layer_idx) 103 | if not stride: 104 | stride = pool_size 105 | output = max_pool_1d(name=name, input_var=input_var, pool_size=pool_size, stride=stride) 106 | self.activations.append((name, output)) 107 | self.layer_idx += 1 108 | return output 109 | 110 | def _avg_pool_1d_layer(self, input_var, pool_size=2, stride=None): 111 | name = "l{}_pool".format(self.layer_idx) 112 | if not stride: 113 | stride = pool_size 114 | output = avg_pool_1d(name=name, input_var=input_var, pool_size=pool_size, stride=stride) 115 | self.activations.append((name, output)) 116 | self.layer_idx += 1 117 | return output 118 | 119 | def _global_avg_pool_1d_layer(self, input_var): 120 | input_shape = input_var.get_shape() 121 | pool_size = input_shape[1].value 122 | name = "l{}_global_avg_pool_1d".format(self.layer_idx) 123 | output = avg_pool_1d(name=name, input_var=input_var, pool_size=pool_size, stride=pool_size) 124 | self.activations.append((name, output)) 125 | self.layer_idx += 1 126 | return output 127 | 128 | def _dropout_layer(self, input_var, keep_prob=0.5): 129 | if self.use_dropout: 130 | name = "l{}_dropout".format(self.layer_idx) 131 | with tf.compat.v1.variable_scope(name) as scope: 132 | if self.is_train: 133 | output = tf.nn.dropout(input_var, keep_prob=keep_prob, name=name) 134 | else: 135 | output = tf.nn.dropout(input_var, keep_prob=1.0, name=name) 136 | self.activations.append((name, output)) 137 | else: 138 | output = input_var 139 | self.layer_idx += 1 140 | return output 141 | 142 | def _attention_layer(self, input_var, mode='cbam', ratio=4, kernel_size=15): 143 | name = "l{}_attention".format(self.layer_idx) 144 | with tf.compat.v1.variable_scope(name) as scope: 145 | if mode =='channel_attention': 146 | output = channel_attention(input_var, name=name, ratio=ratio) 147 | elif mode =='spatial_attention': 148 | output = spatial_attention(input_var, name=name, kernel_size=kernel_size) 149 | else: 150 | output = cbam_block(input_var, name=name, ratio=4) 151 | # output = se_block(input_var, name=name,ratio=4) 152 | output = tf.add(input_var,output) 153 | self.activations.append((name, output)) 154 | self.layer_idx += 1 155 | return output 156 | 157 | def _flatten_layer(self, input_var): 158 | name = "l{}_flat".format(self.layer_idx) 159 | output = flatten(name=name, input_var=input_var) 160 | self.activations.append((name, output)) 161 | self.layer_idx += 1 162 | return output 163 | 164 | def _concat_layer(self, values, axis=1): 165 | name = "l{}_concat".format(self.layer_idx) 166 | output = tf.concat(axis=axis, values=values, name=name) 167 | self.activations.append((name, output)) 168 | self.layer_idx += 1 169 | return output 170 | 171 | def _add_layer(self, values): 172 | name = "l{}_add".format(self.layer_idx) 173 | output = tf.add_n(values, name=name) 174 | self.activations.append((name, output)) 175 | self.layer_idx += 1 176 | return output 177 | 178 | def _fc_layer(self, input_var, n_hiddens, bias=0.0, wd=0.): 179 | name = "l{}_linear".format(self.layer_idx) 180 | output = fc(name=name, input_var=input_var, n_hiddens=n_hiddens, bias=bias, wd=wd) 181 | # output = self._bn_activate_layer(output) 182 | self.activations.append((name, output)) 183 | self.layer_idx += 1 184 | return output 185 | 186 | def _fragment_gru(self, input_var, gru_layers=1, hidden_size=512, use_dropout=False,use_attention=False): 187 | name = "l{}_bi_gru".format(self.layer_idx) 188 | with tf.compat.v1.variable_scope(name) as scope: 189 | def gru_cell(): 190 | cell = tf.compat.v1.nn.rnn_cell.GRUCell(hidden_size) 191 | if use_dropout: 192 | keep_prob = 0.5 if self.is_train else 1.0 193 | cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper( 194 | cell, 195 | output_keep_prob=keep_prob 196 | ) 197 | return cell 198 | 199 | fw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([gru_cell() for _ in range(gru_layers)], 200 | state_is_tuple=True) 201 | bw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([gru_cell() for _ in range(gru_layers)], 202 | state_is_tuple=True) 203 | # Feedforward to MultiRNNCell 204 | list_rnn_inputs = tf.unstack(input_var, axis=1) 205 | # outputs, fw_state, bw_state = tf.nn.bidirectional_rnn( 206 | outputs, fw_state, bw_state = tf.compat.v1.nn.static_bidirectional_rnn( 207 | cell_fw=fw_cell, 208 | cell_bw=bw_cell, 209 | inputs=list_rnn_inputs, 210 | dtype=tf.float32 211 | ) 212 | if use_attention: 213 | T = len(outputs) 214 | network = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size * 2,1,T]) 215 | network = self._attention_layer(network, mode='channel_attention') 216 | network = tf.reduce_mean(network,axis=-1, keepdims=False) 217 | else: 218 | network = outputs[-1] 219 | self.activations.append((name, network)) 220 | self.layer_idx += 1 221 | return network 222 | 223 | def _fragment_cnn(self, input_var): 224 | network_25 = self._conv1d_layer(input_var=input_var, filter_size=25, n_filters=64, stride=6, wd=1e-3) 225 | network_25 = self._max_pool_1d_layer(network_25, pool_size=5) 226 | 227 | network_100 = self._conv1d_layer(input_var=input_var, filter_size=100, n_filters=64, stride=15, wd=1e-3) 228 | network_100 = self._max_pool_1d_layer(network_100, pool_size=2) 229 | 230 | network = self._concat_layer([network_25, network_100], axis=-1) 231 | # network = self._attention_layer(network, mode='cbam') 232 | # # # Dropout 233 | network = self._dropout_layer(network, keep_prob=0.5) 234 | 235 | # 300 236 | # Convolution 237 | network = self._mcb_layer(network, n_filters=64) 238 | print(network.shape) 239 | # ssee.run(network) 240 | # plt.plot(network) 241 | # network = self._dropout_layer(network) 242 | # network = self._attention_layer(network, mode='cbam') 243 | # -------------------------------------- 244 | # network_res = self._global_avg_pool_1d_layer(network) 245 | # network_res = self._conv1d_layer(input_var=network_res, filter_size=1, n_filters=1024, stride=1) 246 | network_res = self._conv1d_layer(input_var=network, filter_size=1, n_filters=1024, stride=1) 247 | network_res = self._global_avg_pool_1d_layer(network_res) 248 | # # network_res = self._dropout_layer(network_res) 249 | network_res = self._flatten_layer(network_res) 250 | # -------------------------------------- 251 | network_cnn = self._conv1d_layer(input_var=network, filter_size=1, n_filters=256, stride=1) 252 | network_cnn = self._max_pool_1d_layer(network_cnn, 10) 253 | # network_cnn = self._conv1d_layer(input_var=network_cnn, filter_size=3, n_filters=256, stride=1) 254 | # network_cnn = self._max_pool_1d_layer(network_cnn, 2) 255 | # network_cnn = self._mcb_layer(network_cnn, n_filters=64) 256 | # network_cnn = self._max_pool_1d_layer(network_cnn, 3) 257 | network_cnn = self._conv1d_layer(input_var=network_cnn, filter_size=3, n_filters=1024, stride=1) 258 | # network_cnn = self._attention_layer(network_cnn, mode ='channel_attention') 259 | network_cnn = self._global_avg_pool_1d_layer(network_cnn) 260 | network_cnn = self._flatten_layer(network_cnn) 261 | # -------------------------------------- 262 | network_rnn = self._conv1d_layer(input_var=network, filter_size=1, n_filters=256, stride=1) 263 | network_rnn = self._avg_pool_1d_layer(network_rnn, 10) 264 | network_rnn_shape = network_rnn.get_shape() 265 | network_rnn_T = network_rnn_shape[1].value 266 | network_rnn_D = network_rnn_shape[3].value 267 | network_rnn = tf.reshape(network_rnn, [-1, network_rnn_T, network_rnn_D]) 268 | network_rnn = self._fragment_gru(network_rnn) 269 | network_rnn = self._flatten_layer(network_rnn) 270 | 271 | network = self._concat_layer([network_cnn, network_rnn, network_res], axis=-1) 272 | # network = self._concat_layer([network_cnn, network_rnn], axis=-1) 273 | # network = self._add_layer([network,network_res]) 274 | # network = self._dropout_layer(network,keep_prob=0.5) 275 | return network 276 | 277 | def build_model(self, input_var): 278 | network = self._fragment_cnn(input_var) 279 | return network 280 | 281 | def init_ops(self): 282 | self._build_placeholder() 283 | 284 | # Get loss and prediction operations 285 | with tf.compat.v1.variable_scope(self.name) as scope: 286 | # Reuse variables for validation 287 | if self.reuse_params: 288 | scope.reuse_variables() 289 | 290 | # Build model 291 | network = self.build_model(input_var=self.input_var) 292 | 293 | # network = self._fc_layer(network, self.n_classes) 294 | # Softmax linear 295 | name = "l{}_softmax_linear".format(self.layer_idx) 296 | network = fc(name=name, input_var=network, n_hiddens=self.n_classes, bias=0.0, wd=0) 297 | self.activations.append((name, network)) 298 | self.layer_idx += 1 299 | 300 | # Outputs of softmax linear are logits 301 | self.logits = network 302 | 303 | ######### Compute loss ######### 304 | # Cross-entropy loss 305 | # loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 306 | # logits=self.logits, 307 | # labels=self.target_var, 308 | # name="sparse_softmax_cross_entropy_with_logits" 309 | # ) 310 | loss = focal_loss(self.target_var, self.logits) 311 | # loss = ghmc_loss(self.target_var, self.logits) 312 | loss = tf.reduce_mean(loss, name="cross_entropy") 313 | 314 | # Regularization loss 315 | regular_loss = tf.add_n( 316 | tf.compat.v1.get_collection("losses", scope=scope.name + "\/"), 317 | name="regular_loss" 318 | ) 319 | 320 | # Total loss 321 | self.loss_op = tf.add(loss, regular_loss) 322 | 323 | # Predictions 324 | self.pred_op = tf.argmax(self.logits, 1) 325 | 326 | 327 | class CCRRSleepNet(CCRRFeatureNet): 328 | 329 | def __init__( 330 | self, 331 | batch_size, 332 | input_dims, 333 | n_classes, 334 | seq_length, 335 | n_rnn_layers, 336 | return_last, 337 | is_train, 338 | reuse_params, 339 | use_dropout_feature, 340 | use_dropout_sequence, 341 | name="ccrrsleepnet" 342 | ): 343 | super(self.__class__, self).__init__( 344 | batch_size=batch_size, 345 | input_dims=input_dims, 346 | n_classes=n_classes, 347 | is_train=is_train, 348 | reuse_params=reuse_params, 349 | use_dropout=use_dropout_feature, 350 | name=name 351 | ) 352 | 353 | self.seq_length = seq_length 354 | self.n_rnn_layers = n_rnn_layers 355 | self.return_last = return_last 356 | 357 | self.use_dropout_sequence = use_dropout_sequence 358 | 359 | def _build_placeholder(self): 360 | # Input 361 | name = "x_train" if self.is_train else "x_valid" 362 | self.input_var = tf.compat.v1.placeholder( 363 | tf.float32, 364 | shape=[self.batch_size * self.seq_length, self.input_dims, 1, 1], 365 | name=name + "_inputs" 366 | ) 367 | # Target 368 | self.target_var = tf.compat.v1.placeholder( 369 | tf.int32, 370 | shape=[self.batch_size * self.seq_length, ], 371 | name=name + "_targets" 372 | ) 373 | 374 | def build_model(self, input_var): 375 | # Create a network with superclass method 376 | network = super(self.__class__, self).build_model( 377 | input_var=self.input_var 378 | ) 379 | 380 | # Residual (or shortcut) connection 381 | output_conns = [] 382 | 383 | # Fully-connected to select some part of the output to add with the output from bi-directional LSTM 384 | name = "l{}_fc".format(self.layer_idx) 385 | with tf.compat.v1.variable_scope(name) as scope: 386 | output = fc(name="fc", input_var=network, n_hiddens=1024, bias=None, wd=0) 387 | output = batch_norm_new(name="bn", input_var=output, is_train=self.is_train) 388 | output = mish(output) 389 | self.activations.append((name, output)) 390 | self.layer_idx += 1 391 | output_conns.append(output) 392 | 393 | # output_conns.append(network) 394 | 395 | ###################################################################### 396 | 397 | # Reshape the input from (batch_size * seq_length, input_dim) to 398 | # (batch_size, seq_length, input_dim) 399 | name = "l{}_reshape_seq".format(self.layer_idx) 400 | input_dim = network.get_shape()[-1].value 401 | seq_input = tf.reshape(network, 402 | shape=[-1, self.seq_length, input_dim], 403 | name=name) 404 | assert self.batch_size == seq_input.get_shape()[0].value 405 | self.activations.append((name, seq_input)) 406 | self.layer_idx += 1 407 | 408 | # Bidirectional LSTM network 409 | name = "l{}_bi_lstm".format(self.layer_idx) 410 | hidden_size = 512 # will output 1024 (512 forward, 512 backward) 411 | with tf.compat.v1.variable_scope(name) as scope: 412 | 413 | def lstm_cell(): 414 | 415 | cell = tf.compat.v1.nn.rnn_cell.LSTMCell(hidden_size, 416 | use_peepholes=True, 417 | state_is_tuple=True, 418 | reuse=tf.compat.v1.get_variable_scope().reuse) 419 | if self.use_dropout_sequence: 420 | keep_prob = 0.5 if self.is_train else 1.0 421 | cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper( 422 | cell, 423 | output_keep_prob=keep_prob 424 | ) 425 | 426 | return cell 427 | 428 | fw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], 429 | state_is_tuple=True) 430 | bw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], 431 | state_is_tuple=True) 432 | 433 | # Initial state of RNN 434 | self.fw_initial_state = fw_cell.zero_state(self.batch_size, tf.float32) 435 | self.bw_initial_state = bw_cell.zero_state(self.batch_size, tf.float32) 436 | 437 | # Feedforward to MultiRNNCell 438 | list_rnn_inputs = tf.unstack(seq_input, axis=1) 439 | # outputs, fw_state, bw_state = tf.nn.bidirectional_rnn( 440 | outputs, fw_state, bw_state = tf.compat.v1.nn.static_bidirectional_rnn( 441 | cell_fw=fw_cell, 442 | cell_bw=bw_cell, 443 | inputs=list_rnn_inputs, 444 | initial_state_fw=self.fw_initial_state, 445 | initial_state_bw=self.bw_initial_state 446 | ) 447 | 448 | if self.return_last: 449 | network = outputs[-1] 450 | else: 451 | network = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size * 2], 452 | name=name) 453 | self.activations.append((name, network)) 454 | self.layer_idx += 1 455 | 456 | self.fw_final_state = fw_state 457 | self.bw_final_state = bw_state 458 | 459 | # Append output 460 | output_conns.append(network) 461 | 462 | ###################################################################### 463 | 464 | # Concat 465 | # network = tf.concat(output_conns, axis=1) 466 | 467 | # Add 468 | name = "l{}_add".format(self.layer_idx) 469 | network = tf.add_n(output_conns, name=name) 470 | self.activations.append((name, network)) 471 | self.layer_idx += 1 472 | 473 | # Dropout 474 | if self.use_dropout_sequence: 475 | name = "l{}_dropout".format(self.layer_idx) 476 | if self.is_train: 477 | network = tf.nn.dropout(network, keep_prob=0.5, name=name) 478 | else: 479 | network = tf.nn.dropout(network, keep_prob=1.0, name=name) 480 | self.activations.append((name, network)) 481 | self.layer_idx += 1 482 | 483 | return network 484 | 485 | def init_ops(self): 486 | self._build_placeholder() 487 | 488 | # Get loss and prediction operations 489 | with tf.compat.v1.variable_scope(self.name) as scope: 490 | # Reuse variables for validation 491 | if self.reuse_params: 492 | scope.reuse_variables() 493 | 494 | # Build model 495 | network = self.build_model(input_var=self.input_var) 496 | 497 | # Softmax linear 498 | name = "l{}_softmax_linear".format(self.layer_idx) 499 | network = fc(name=name, input_var=network, n_hiddens=self.n_classes, bias=0.0, wd=0) 500 | self.activations.append((name, network)) 501 | self.layer_idx += 1 502 | 503 | # Outputs of softmax linear are logits 504 | self.logits = network 505 | # Weighted cross-entropy loss for a sequence of logits (per example) 506 | loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example( 507 | [self.logits], 508 | [self.target_var], 509 | [tf.ones([self.batch_size * self.seq_length])], 510 | softmax_loss_function=focal_loss, 511 | # softmax_loss_function=ghmc_loss, 512 | name="sequence_loss_by_example" 513 | ) 514 | loss = tf.reduce_sum(loss) / self.batch_size 515 | 516 | # Regularization loss 517 | regular_loss = tf.add_n( 518 | tf.compat.v1.get_collection("losses", scope=scope.name + "\/"), 519 | name="regular_loss" 520 | ) 521 | 522 | # Total loss 523 | self.loss_op = tf.add(loss, regular_loss) 524 | 525 | # Predictions 526 | self.pred_op = tf.argmax(self.logits, 1) 527 | -------------------------------------------------------------------------------- /ccrrsleep/nn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/akaraspt/deepsleepnet 3 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 4 | ''' 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.training import moving_averages 11 | 12 | 13 | def mish(x): 14 | return x * tf.math.tanh(tf.nn.softplus(x)) 15 | 16 | 17 | def _create_variable(name, shape, initializer): 18 | var = tf.compat.v1.get_variable(name, shape, initializer=initializer) 19 | return var 20 | 21 | 22 | def variable_with_weight_decay(name, shape, wd=None): 23 | # Get the number of input and output parameters 24 | if len(shape) == 2: 25 | fan_in = shape[0] 26 | fan_out = shape[1] 27 | elif len(shape) == 4: 28 | receptive_field_size = np.prod(shape[:2]) 29 | fan_in = shape[-2] * receptive_field_size 30 | fan_out = shape[-1] * receptive_field_size 31 | else: 32 | # no specific assumptions 33 | fan_in = np.sqrt(np.prod(shape)) 34 | fan_out = np.sqrt(np.prod(shape)) 35 | 36 | # He et al. 2015 - http://arxiv.org/abs/1502.01852 37 | stddev = np.sqrt(2.0 / fan_in) 38 | initializer = tf.truncated_normal_initializer(stddev=stddev) 39 | 40 | # # Xavier 41 | # initializer = tf.contrib.layers.xavier_initializer() 42 | 43 | # Create or get the existing variable 44 | var = _create_variable( 45 | name, 46 | shape, 47 | initializer 48 | ) 49 | 50 | # L2 weight decay 51 | if wd is not None: 52 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name="weight_loss") 53 | tf.compat.v1.add_to_collection("losses", weight_decay) 54 | 55 | return var 56 | 57 | 58 | def conv_1d(name, input_var, filter_shape, stride, dilations=1, padding="SAME", 59 | bias=None, wd=None): 60 | with tf.compat.v1.variable_scope(name) as scope: 61 | # Trainable parameters 62 | kernel = variable_with_weight_decay( 63 | "weights", 64 | shape=filter_shape, 65 | wd=wd 66 | ) 67 | 68 | # Convolution 69 | output_var = tf.nn.conv2d( 70 | input_var, 71 | kernel, 72 | [1, stride, 1, 1], 73 | padding=padding, 74 | dilations=[1, dilations, 1, 1] 75 | ) 76 | 77 | # Bias 78 | if bias is not None: 79 | biases = _create_variable( 80 | "biases", 81 | [filter_shape[-1]], 82 | tf.constant_initializer(bias) 83 | ) 84 | output_var = tf.nn.bias_add(output_var, biases) 85 | 86 | return output_var 87 | 88 | def depthwise_conv_1d(name, input_var, filter_size, in_channels, 89 | channel_multiplier=1, stride=1, dilations=1, padding="SAME", 90 | bias=None, wd=None): 91 | with tf.compat.v1.variable_scope(name) as scope: 92 | # Trainable parameters 93 | kernel = variable_with_weight_decay( 94 | "weights", 95 | shape=[filter_size, 1, in_channels, channel_multiplier], 96 | wd=wd 97 | ) 98 | 99 | # Convolution 100 | output_var = tf.nn.depthwise_conv2d( 101 | input_var, 102 | kernel, 103 | [1, stride, 1, 1], 104 | padding=padding, 105 | dilations=[dilations,dilations] 106 | ) 107 | 108 | # Bias 109 | if bias is not None: 110 | biases = _create_variable( 111 | "biases", 112 | [in_channels*channel_multiplier], 113 | tf.constant_initializer(bias) 114 | ) 115 | output_var = tf.nn.bias_add(output_var, biases) 116 | 117 | return output_var 118 | 119 | 120 | def conv_2d(name, input_var, filter_shape, stride, dilations=1, padding="SAME", 121 | bias=None, wd=None): 122 | with tf.compat.v1.variable_scope(name) as scope: 123 | # Trainable parameters 124 | kernel = variable_with_weight_decay( 125 | "weights", 126 | shape=filter_shape, 127 | wd=wd 128 | ) 129 | 130 | # Convolution 131 | output_var = tf.nn.conv2d( 132 | input_var, 133 | kernel, 134 | [1, stride[0], stride[1], 1], 135 | padding=padding, 136 | dilations=[1, dilations, dilations, 1] 137 | ) 138 | 139 | # Bias 140 | if bias is not None: 141 | biases = _create_variable( 142 | "biases", 143 | [filter_shape[-1]], 144 | tf.constant_initializer(bias) 145 | ) 146 | output_var = tf.nn.bias_add(output_var, biases) 147 | 148 | return output_var 149 | 150 | 151 | def conv_transpose_1d(name, input_var, filter_shape, output_shape, stride, dilations=1, padding="SAME", 152 | bias=None, wd=None): 153 | with tf.compat.v1.variable_scope(name) as scope: 154 | # Trainable parameters 155 | kernel = variable_with_weight_decay( 156 | "weights", 157 | shape=filter_shape, 158 | wd=wd 159 | ) 160 | 161 | # Convolution 162 | output_var = tf.nn.conv2d_transpose( 163 | input_var, 164 | kernel, 165 | output_shape=output_shape, 166 | strides=[1, stride, 1, 1], 167 | padding=padding, 168 | dilations=[1, dilations, 1, 1] 169 | ) 170 | 171 | # Bias 172 | if bias is not None: 173 | biases = _create_variable( 174 | "biases", 175 | [filter_shape[-1]], 176 | tf.constant_initializer(bias) 177 | ) 178 | output_var = tf.nn.bias_add(output_var, biases) 179 | 180 | return output_var 181 | 182 | 183 | def max_pool_1d(name, input_var, pool_size, stride, padding="SAME"): 184 | output_var = tf.nn.max_pool2d( 185 | input_var, 186 | ksize=[1, pool_size, 1, 1], 187 | strides=[1, stride, 1, 1], 188 | padding=padding, 189 | name=name 190 | ) 191 | 192 | return output_var 193 | 194 | 195 | def max_pool_2d(name, input_var, pool_size, stride, padding="SAME"): 196 | output_var = tf.nn.max_pool2d( 197 | input_var, 198 | ksize=[1, pool_size[0], pool_size[1], 1], 199 | strides=[1, stride[0], stride[1], 1], 200 | padding=padding, 201 | name=name 202 | ) 203 | 204 | return output_var 205 | 206 | 207 | def avg_pool_1d(name, input_var, pool_size, stride, padding="SAME"): 208 | output_var = tf.nn.avg_pool( 209 | input_var, 210 | ksize=[1, pool_size, 1, 1], 211 | strides=[1, stride, 1, 1], 212 | padding=padding, 213 | name=name 214 | ) 215 | 216 | return output_var 217 | 218 | 219 | def avg_pool_2d(name, input_var, pool_size, stride, padding="SAME"): 220 | output_var = tf.nn.avg_pool( 221 | input_var, 222 | ksize=[1, pool_size[0], pool_size[1], 1], 223 | strides=[1, stride[0], stride[1], 1], 224 | padding=padding, 225 | name=name 226 | ) 227 | 228 | return output_var 229 | 230 | 231 | def upsampling_1d(input_var, up_size): 232 | output_var = tf.compat.v1.keras.layers.UpSampling2D( 233 | size=(up_size, 1), 234 | data_format=None, 235 | interpolation='bilinear', 236 | )(input_var) 237 | return output_var 238 | 239 | 240 | def fc(name, input_var, n_hiddens, bias=None, wd=None): 241 | with tf.compat.v1.variable_scope(name) as scope: 242 | # Get input dimension 243 | input_dim = input_var.get_shape()[-1].value 244 | 245 | # Trainable parameters 246 | weights = variable_with_weight_decay( 247 | "weights", 248 | shape=[input_dim, n_hiddens], 249 | wd=wd 250 | ) 251 | 252 | # Multiply weights 253 | output_var = tf.matmul(input_var, weights) 254 | 255 | # Bias 256 | if bias is not None: 257 | biases = _create_variable( 258 | "biases", 259 | [n_hiddens], 260 | tf.constant_initializer(bias) 261 | ) 262 | output_var = tf.add(output_var, biases) 263 | 264 | return output_var 265 | 266 | 267 | def leaky_relu(name, input_var, alpha=0.01): 268 | return tf.maximum( 269 | input_var, 270 | alpha * input_var, 271 | name="leaky_relu" 272 | ) 273 | 274 | 275 | def batch_norm(name, input_var, is_train, decay=0.999, epsilon=1e-5): 276 | """Batch normalization on fully-connected or convolutional maps. 277 | Source: 278 | """ 279 | 280 | inputs_shape = input_var.get_shape() 281 | axis = list(range(len(inputs_shape) - 1)) 282 | params_shape = inputs_shape[-1:] 283 | 284 | with tf.compat.v1.variable_scope(name) as scope: 285 | beta = tf.compat.v1.get_variable(name="beta", shape=params_shape, 286 | initializer=tf.constant_initializer(0.0)) 287 | gamma = tf.compat.v1.get_variable(name="gamma", shape=params_shape, 288 | initializer=tf.constant_initializer(1.0)) 289 | batch_mean, batch_var = tf.nn.moments(input_var, 290 | axis, 291 | name="moments") 292 | ema = tf.train.ExponentialMovingAverage(decay=decay) 293 | 294 | def mean_var_with_update(): 295 | ema_apply_op = ema.apply([batch_mean, batch_var]) 296 | with tf.control_dependencies([ema_apply_op]): 297 | return tf.identity(batch_mean), tf.identity(batch_var) 298 | 299 | mean, var = tf.cond( 300 | is_train, 301 | mean_var_with_update, 302 | lambda: (ema.average(batch_mean), ema.average(batch_var)) 303 | ) 304 | normed = tf.nn.batch_normalization( 305 | x=input_var, 306 | mean=mean, 307 | variance=var, 308 | offset=beta, 309 | scale=gamma, 310 | variance_epsilon=epsilon, 311 | name="tf_bn" 312 | ) 313 | return normed 314 | 315 | 316 | def batch_norm_new(name, input_var, is_train, decay=0.999, epsilon=1e-5): 317 | """Batch normalization modified from BatchNormLayer in Tensorlayer. 318 | Source: 319 | """ 320 | 321 | inputs_shape = input_var.get_shape() 322 | axis = list(range(len(inputs_shape) - 1)) 323 | params_shape = inputs_shape[-1:] 324 | 325 | with tf.compat.v1.variable_scope(name) as scope: 326 | # Trainable beta and gamma variables 327 | beta = tf.compat.v1.get_variable('beta', 328 | shape=params_shape, 329 | initializer=tf.zeros_initializer()) 330 | gamma = tf.compat.v1.get_variable('gamma', 331 | shape=params_shape, 332 | initializer=tf.random_normal_initializer(mean=1.0, stddev=0.002)) 333 | 334 | # Moving mean and variance updated during training 335 | moving_mean = tf.compat.v1.get_variable('moving_mean', 336 | params_shape, 337 | initializer=tf.zeros_initializer(), 338 | trainable=False) 339 | moving_variance = tf.compat.v1.get_variable('moving_variance', 340 | params_shape, 341 | initializer=tf.constant_initializer(1.), 342 | trainable=False) 343 | 344 | # Compute mean and variance along axis 345 | batch_mean, batch_variance = tf.nn.moments(input_var, axis, name='moments') 346 | 347 | # Define ops to update moving_mean and moving_variance 348 | update_moving_mean = moving_averages.assign_moving_average(moving_mean, batch_mean, decay, zero_debias=False) 349 | update_moving_variance = moving_averages.assign_moving_average(moving_variance, batch_variance, decay, 350 | zero_debias=False) 351 | 352 | # Define a function that : 353 | # 1. Update moving_mean & moving_variance with batch_mean & batch_variance 354 | # 2. Then return the batch_mean & batch_variance 355 | def mean_var_with_update(): 356 | with tf.control_dependencies([update_moving_mean, update_moving_variance]): 357 | return tf.identity(batch_mean), tf.identity(batch_variance) 358 | 359 | # Perform different ops for training and testing 360 | if is_train: 361 | mean, variance = mean_var_with_update() 362 | normed = tf.nn.batch_normalization(input_var, mean, variance, beta, gamma, epsilon) 363 | 364 | else: 365 | normed = tf.nn.batch_normalization(input_var, moving_mean, moving_variance, beta, gamma, epsilon) 366 | # mean, variance = tf.cond( 367 | # is_train, 368 | # mean_var_with_update, # Training 369 | # lambda: (moving_mean, moving_variance) # Testing - it will use the moving_mean and moving_variance (fixed during test) that are computed during training 370 | # ) 371 | # normed = tf.nn.batch_normalization(input_var, mean, variance, beta, gamma, epsilon) 372 | 373 | return normed 374 | 375 | 376 | def flatten(name, input_var): 377 | dim = 1 378 | for d in input_var.get_shape()[1:].as_list(): 379 | dim *= d 380 | output_var = tf.reshape(input_var, 381 | shape=[-1, dim], 382 | name=name) 383 | 384 | return output_var 385 | -------------------------------------------------------------------------------- /ccrrsleep/optimize.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/akaraspt/deepsleepnet 3 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 4 | ''' 5 | 6 | 7 | import tensorflow as tf 8 | 9 | """AMSGrad for TensorFlow.""" 10 | 11 | from tensorflow.python.eager import context 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import control_flow_ops 14 | from tensorflow.python.ops import math_ops 15 | from tensorflow.python.ops import resource_variable_ops 16 | from tensorflow.python.ops import state_ops 17 | from tensorflow.python.ops import variable_scope 18 | from tensorflow.python.training import optimizer 19 | 20 | 21 | class AMSGrad(optimizer.Optimizer): 22 | def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"): 23 | super(AMSGrad, self).__init__(use_locking, name) 24 | self._lr = learning_rate 25 | self._beta1 = beta1 26 | self._beta2 = beta2 27 | self._epsilon = epsilon 28 | 29 | self._lr_t = None 30 | self._beta1_t = None 31 | self._beta2_t = None 32 | self._epsilon_t = None 33 | 34 | self._beta1_power = None 35 | self._beta2_power = None 36 | 37 | def _create_slots(self, var_list): 38 | first_var = min(var_list, key=lambda x: x.name) 39 | 40 | create_new = self._beta1_power is None 41 | if not create_new and context.in_graph_mode(): 42 | create_new = (self._beta1_power.graph is not first_var.graph) 43 | 44 | if create_new: 45 | with ops.colocate_with(first_var): 46 | self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False) 47 | self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False) 48 | # Create slots for the first and second moments. 49 | for v in var_list : 50 | self._zeros_slot(v, "m", self._name) 51 | self._zeros_slot(v, "v", self._name) 52 | self._zeros_slot(v, "vhat", self._name) 53 | 54 | def _prepare(self): 55 | self._lr_t = ops.convert_to_tensor(self._lr) 56 | self._beta1_t = ops.convert_to_tensor(self._beta1) 57 | self._beta2_t = ops.convert_to_tensor(self._beta2) 58 | self._epsilon_t = ops.convert_to_tensor(self._epsilon) 59 | 60 | def _apply_dense(self, grad, var): 61 | beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) 62 | beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) 63 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 64 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 65 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 66 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 67 | 68 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 69 | 70 | # m_t = beta1 * m + (1 - beta1) * g_t 71 | m = self.get_slot(var, "m") 72 | m_scaled_g_values = grad * (1 - beta1_t) 73 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 74 | 75 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 76 | v = self.get_slot(var, "v") 77 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 78 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 79 | 80 | # amsgrad 81 | vhat = self.get_slot(var, "vhat") 82 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 83 | v_sqrt = math_ops.sqrt(vhat_t) 84 | 85 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 86 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 87 | 88 | def _resource_apply_dense(self, grad, var): 89 | var = var.handle 90 | beta1_power = math_ops.cast(self._beta1_power, grad.dtype.base_dtype) 91 | beta2_power = math_ops.cast(self._beta2_power, grad.dtype.base_dtype) 92 | lr_t = math_ops.cast(self._lr_t, grad.dtype.base_dtype) 93 | beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype) 94 | beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype) 95 | epsilon_t = math_ops.cast(self._epsilon_t, grad.dtype.base_dtype) 96 | 97 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 98 | 99 | # m_t = beta1 * m + (1 - beta1) * g_t 100 | m = self.get_slot(var, "m").handle 101 | m_scaled_g_values = grad * (1 - beta1_t) 102 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 103 | 104 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 105 | v = self.get_slot(var, "v").handle 106 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 107 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 108 | 109 | # amsgrad 110 | vhat = self.get_slot(var, "vhat").handle 111 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 112 | v_sqrt = math_ops.sqrt(vhat_t) 113 | 114 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 115 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 116 | 117 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 118 | beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) 119 | beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) 120 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 121 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 122 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 123 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 124 | 125 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 126 | 127 | # m_t = beta1 * m + (1 - beta1) * g_t 128 | m = self.get_slot(var, "m") 129 | m_scaled_g_values = grad * (1 - beta1_t) 130 | m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) 131 | with ops.control_dependencies([m_t]): 132 | m_t = scatter_add(m, indices, m_scaled_g_values) 133 | 134 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 135 | v = self.get_slot(var, "v") 136 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 137 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 138 | with ops.control_dependencies([v_t]): 139 | v_t = scatter_add(v, indices, v_scaled_g_values) 140 | 141 | # amsgrad 142 | vhat = self.get_slot(var, "vhat") 143 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 144 | v_sqrt = math_ops.sqrt(vhat_t) 145 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 146 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 147 | 148 | def _apply_sparse(self, grad, var): 149 | return self._apply_sparse_shared( 150 | grad.values, var, grad.indices, 151 | lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda 152 | x, i, v, use_locking=self._use_locking)) 153 | 154 | def _resource_scatter_add(self, x, i, v): 155 | with ops.control_dependencies( 156 | [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 157 | return x.value() 158 | 159 | def _resource_apply_sparse(self, grad, var, indices): 160 | return self._apply_sparse_shared( 161 | grad, var, indices, self._resource_scatter_add) 162 | 163 | def _finish(self, update_ops, name_scope): 164 | # Update the power accumulators. 165 | with ops.control_dependencies(update_ops): 166 | with ops.colocate_with(self._beta1_power): 167 | update_beta1 = self._beta1_power.assign( 168 | self._beta1_power * self._beta1_t, 169 | use_locking=self._use_locking) 170 | update_beta2 = self._beta2_power.assign( 171 | self._beta2_power * self._beta2_t, 172 | use_locking=self._use_locking) 173 | return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], 174 | name=name_scope) 175 | 176 | 177 | def adam(loss, lr, train_vars, beta1=0.9, beta2=0.999, epsilon=1e-8): 178 | opt = AMSGrad( 179 | learning_rate=lr, 180 | beta1=beta1, 181 | beta2=beta2, 182 | epsilon=epsilon, 183 | ) 184 | grads_and_vars = opt.compute_gradients(loss, train_vars) 185 | apply_gradient_op = opt.apply_gradients(grads_and_vars) 186 | return apply_gradient_op, grads_and_vars 187 | 188 | 189 | def adam_clipping(loss, lr, train_vars, beta1=0.9, beta2=0.999, 190 | epsilon=1e-8, clip_value=5.0): 191 | grads, _ = tf.clip_by_global_norm(tf.gradients(loss, train_vars), 192 | clip_value) 193 | capped_gvs = list(zip(grads, train_vars)) 194 | opt = AMSGrad( 195 | learning_rate=lr, 196 | beta1=beta1, 197 | beta2=beta2, 198 | epsilon=epsilon, 199 | ) 200 | apply_gradient_op = opt.apply_gradients(capped_gvs) 201 | return apply_gradient_op, capped_gvs 202 | 203 | 204 | def adam_clipping_list_lr(loss, list_lrs, list_train_vars, 205 | beta1=0.9, beta2=0.999, 206 | epsilon=1e-8, clip_value=5.0): 207 | assert len(list_lrs) == len(list_train_vars) 208 | 209 | train_vars = [] 210 | for v in list_train_vars: 211 | if len(train_vars) == 0: 212 | train_vars = list(v) 213 | else: 214 | train_vars.extend(v) 215 | 216 | grads, _ = tf.clip_by_global_norm(tf.gradients(loss, train_vars), 217 | clip_value) 218 | 219 | offset = 0 220 | apply_gradient_ops = [] 221 | grads_and_vars = [] 222 | for i, v in enumerate(list_train_vars): 223 | g = grads[offset:offset+len(v)] 224 | opt = AMSGrad( 225 | learning_rate=list_lrs[i], 226 | beta1=beta1, 227 | beta2=beta2, 228 | epsilon=epsilon, 229 | ) 230 | gvs = list(zip(g, v)) 231 | apply_gradient_op = opt.apply_gradients(gvs) 232 | 233 | apply_gradient_ops.append(apply_gradient_op) 234 | if len(grads_and_vars) == 0: 235 | grads_and_vars = list(gvs) 236 | else: 237 | grads_and_vars.extend(gvs) 238 | offset += len(v) 239 | 240 | apply_gradient_ops = tf.group(*apply_gradient_ops) 241 | return apply_gradient_ops, grads_and_vars 242 | 243 | 244 | # 245 | # 246 | # def adam(loss, lr, train_vars, beta1=0.9, beta2=0.999, epsilon=1e-8): 247 | # opt = tf.compat.v1.train.AdamOptimizer( 248 | # learning_rate=lr, 249 | # beta1=beta1, 250 | # beta2=beta2, 251 | # epsilon=epsilon, 252 | # name="Adam" 253 | # ) 254 | # grads_and_vars = opt.compute_gradients(loss, train_vars) 255 | # apply_gradient_op = opt.apply_gradients(grads_and_vars) 256 | # return apply_gradient_op, grads_and_vars 257 | # 258 | # 259 | # def adam_clipping(loss, lr, train_vars, beta1=0.9, beta2=0.999, 260 | # epsilon=1e-8, clip_value=5.0): 261 | # grads, _ = tf.clip_by_global_norm(tf.gradients(loss, train_vars), 262 | # clip_value) 263 | # capped_gvs = list(zip(grads, train_vars)) 264 | # opt = tf.compat.v1.train.AdamOptimizer( 265 | # learning_rate=lr, 266 | # beta1=beta1, 267 | # beta2=beta2, 268 | # epsilon=epsilon, 269 | # name="Adam" 270 | # ) 271 | # apply_gradient_op = opt.apply_gradients(capped_gvs) 272 | # return apply_gradient_op, capped_gvs 273 | # 274 | # 275 | # def adam_clipping_list_lr(loss, list_lrs, list_train_vars, 276 | # beta1=0.9, beta2=0.999, 277 | # epsilon=1e-8, clip_value=5.0): 278 | # assert len(list_lrs) == len(list_train_vars) 279 | # 280 | # train_vars = [] 281 | # for v in list_train_vars: 282 | # if len(train_vars) == 0: 283 | # train_vars = list(v) 284 | # else: 285 | # train_vars.extend(v) 286 | # 287 | # grads, _ = tf.clip_by_global_norm(tf.gradients(loss, train_vars), 288 | # clip_value) 289 | # 290 | # offset = 0 291 | # apply_gradient_ops = [] 292 | # grads_and_vars = [] 293 | # for i, v in enumerate(list_train_vars): 294 | # g = grads[offset:offset+len(v)] 295 | # opt = tf.compat.v1.train.AdamOptimizer( 296 | # learning_rate=list_lrs[i], 297 | # beta1=beta1, 298 | # beta2=beta2, 299 | # epsilon=epsilon, 300 | # name="Adam" 301 | # ) 302 | # gvs = list(zip(g, v)) 303 | # apply_gradient_op = opt.apply_gradients(gvs) 304 | # 305 | # apply_gradient_ops.append(apply_gradient_op) 306 | # if len(grads_and_vars) == 0: 307 | # grads_and_vars = list(gvs) 308 | # else: 309 | # grads_and_vars.extend(gvs) 310 | # offset += len(v) 311 | # 312 | # apply_gradient_ops = tf.group(*apply_gradient_ops) 313 | # return apply_gradient_ops, grads_and_vars 314 | -------------------------------------------------------------------------------- /ccrrsleep/sleep_stage.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/akaraspt/deepsleepnet 3 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 4 | ''' 5 | 6 | 7 | # Label values 8 | W = 0 9 | N1 = 1 10 | N2 = 2 11 | N3 = 3 12 | REM = 4 13 | UNKNOWN = 5 14 | 15 | NUM_CLASSES = 5 # exclude UNKNOWN 16 | 17 | class_dict = { 18 | 0: "W", 19 | 1: "N1", 20 | 2: "N2", 21 | 3: "N3", 22 | 4: "REM" 23 | } 24 | 25 | EPOCH_SEC_LEN = 30 # seconds 26 | SAMPLING_RATE = 256 27 | 28 | def print_n_samples_each_class(labels): 29 | import numpy as np 30 | unique_labels = np.unique(labels) 31 | for c in unique_labels: 32 | n_samples = len(np.where(labels == c)[0]) 33 | print("{}: {}".format(class_dict[c], n_samples)) 34 | -------------------------------------------------------------------------------- /ccrrsleep/trainer.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import re 4 | import time 5 | 6 | from datetime import datetime 7 | 8 | import matplotlib 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import tensorflow as tf 14 | from sklearn.metrics import confusion_matrix, f1_score 15 | 16 | from ccrrsleep.data_loader import NonSeqDataLoader, SeqDataLoader 17 | from ccrrsleep.model import CCRRFeatureNet, CCRRSleepNet 18 | from ccrrsleep.optimize import adam, adam_clipping_list_lr 19 | from ccrrsleep.utils import iterate_minibatches, iterate_batch_seq_minibatches 20 | 21 | 22 | class Trainer(object): 23 | 24 | def __init__( 25 | self, 26 | interval_plot_filter=50, 27 | interval_save_model=100, 28 | interval_print_cm=10 29 | ): 30 | self.interval_plot_filter = interval_plot_filter # 间隔多久画出特征图 31 | self.interval_save_model = interval_save_model # 间隔多久保存模型 32 | self.interval_print_cm = interval_print_cm # 间隔多久打印混淆矩阵 33 | 34 | def print_performance(self, sess, output_dir, network_name, 35 | n_train_examples, n_valid_examples, 36 | train_cm, valid_cm, epoch, n_epochs, 37 | train_duration, train_loss, train_acc, train_f1, 38 | valid_duration, valid_loss, valid_acc, valid_f1): 39 | # Get regularization loss 获得正则化损失 40 | train_reg_loss = tf.add_n(tf.compat.v1.get_collection("losses", scope=network_name + "\/")) 41 | train_reg_loss_value = sess.run(train_reg_loss) 42 | valid_reg_loss_value = train_reg_loss_value 43 | 44 | # Print performance 45 | if ((epoch + 1) % self.interval_print_cm == 0) or ((epoch + 1) == n_epochs): 46 | print(" ") 47 | print("[{}] epoch {}:".format( 48 | datetime.now(), epoch + 1 49 | )) 50 | print(( 51 | "train ({:.3f} sec): n={}, loss={:.3f} ({:.3f}), acc={:.3f}, " 52 | "f1={:.3f}".format( 53 | train_duration, n_train_examples, 54 | train_loss, train_reg_loss_value, 55 | train_acc, train_f1 56 | ) 57 | )) 58 | print(train_cm) 59 | print(( 60 | "valid ({:.3f} sec): n={}, loss={:.3f} ({:.3f}), acc={:.3f}, " 61 | "f1={:.3f}".format( 62 | valid_duration, n_valid_examples, 63 | valid_loss, valid_reg_loss_value, 64 | valid_acc, valid_f1 65 | ) 66 | )) 67 | print(valid_cm) 68 | print(" ") 69 | else: 70 | print(( 71 | "epoch {}: " 72 | "train ({:.2f} sec): n={}, loss={:.3f} ({:.3f}), " 73 | "acc={:.3f}, f1={:.3f} | " 74 | "valid ({:.2f} sec): n={}, loss={:.3f} ({:.3f}), " 75 | "acc={:.3f}, f1={:.3f}".format( 76 | epoch + 1, 77 | train_duration, n_train_examples, 78 | train_loss, train_reg_loss_value, 79 | train_acc, train_f1, 80 | valid_duration, n_valid_examples, 81 | valid_loss, valid_reg_loss_value, 82 | valid_acc, valid_f1 83 | ) 84 | )) 85 | 86 | def print_network(self, network): 87 | print("inputs ({}): {}".format( 88 | network.inputs.name, network.inputs.get_shape() 89 | )) 90 | print("targets ({}): {}".format( 91 | network.targets.name, network.targets.get_shape() 92 | )) 93 | for name, act in network.activations: 94 | print("{} ({}): {}".format(name, act.name, act.get_shape())) 95 | print(" ") 96 | 97 | def plot_filters(self, sess, epoch, reg_exp, output_dir, n_viz_filters): 98 | conv_weight = re.compile(reg_exp) 99 | for v in tf.compat.v1.trainable_variables(): 100 | 101 | value = sess.run(v) 102 | if conv_weight.match(v.name): 103 | weights = np.squeeze(value) 104 | # Only plot conv that has one channel 105 | if len(weights.shape) > 2: 106 | continue 107 | weights = weights.T 108 | plt.figure(figsize=(18, 10)) 109 | plt.title(v.name) 110 | for w_idx in range(n_viz_filters): 111 | plt.subplot(4, 4, w_idx + 1) 112 | plt.plot(weights[w_idx]) 113 | plt.axis("tight") 114 | plt.savefig(os.path.join( 115 | output_dir, "{}_{}.png".format( 116 | v.name.replace("/", "_").replace(":0", ""), 117 | epoch + 1 118 | ) 119 | )) 120 | plt.close("all") 121 | 122 | 123 | class CCRRFeatureNetTrainer(Trainer): 124 | 125 | def __init__( 126 | self, 127 | data_dir, 128 | output_dir, 129 | n_folds, 130 | fold_idx, 131 | batch_size, 132 | input_dims, 133 | n_classes, 134 | interval_plot_filter=50, 135 | interval_save_model=100, 136 | interval_print_cm=10 137 | ): 138 | super(self.__class__, self).__init__( 139 | interval_plot_filter=interval_plot_filter, 140 | interval_save_model=interval_save_model, 141 | interval_print_cm=interval_print_cm 142 | ) 143 | 144 | self.data_dir = data_dir 145 | self.output_dir = output_dir 146 | self.n_folds = n_folds 147 | self.fold_idx = fold_idx 148 | self.batch_size = batch_size 149 | self.input_dims = input_dims 150 | self.n_classes = n_classes 151 | 152 | def _run_epoch(self, sess, network, inputs, targets, train_op, is_train): 153 | start_time = time.time() 154 | y = [] 155 | y_true = [] 156 | total_loss, n_batches = 0.0, 0 157 | is_shuffle = True if is_train else False 158 | for x_batch, y_batch in iterate_minibatches(inputs, 159 | targets, 160 | self.batch_size, 161 | shuffle=is_shuffle): 162 | feed_dict = { 163 | network.input_var: x_batch, 164 | network.target_var: y_batch 165 | } 166 | 167 | _, loss_value, y_pred = sess.run( 168 | [train_op, network.loss_op, network.pred_op], 169 | feed_dict=feed_dict 170 | ) 171 | 172 | total_loss += loss_value 173 | n_batches += 1 174 | y.append(y_pred) 175 | y_true.append(y_batch) 176 | 177 | duration = time.time() - start_time 178 | total_loss /= n_batches 179 | total_y_pred = np.hstack(y) 180 | total_y_true = np.hstack(y_true) 181 | 182 | return total_y_true, total_y_pred, total_loss, duration 183 | 184 | def train(self, n_epochs, resume): 185 | with tf.Graph().as_default(), tf.compat.v1.Session() as sess: 186 | # Build training and validation networks 187 | train_net = CCRRFeatureNet( 188 | batch_size=self.batch_size, 189 | input_dims=self.input_dims, 190 | n_classes=self.n_classes, 191 | is_train=True, 192 | reuse_params=False, 193 | use_dropout=True 194 | ) 195 | valid_net = CCRRFeatureNet( 196 | batch_size=self.batch_size, 197 | input_dims=self.input_dims, 198 | n_classes=self.n_classes, 199 | is_train=False, 200 | reuse_params=True, 201 | use_dropout=True 202 | ) 203 | 204 | # Initialize parameters 205 | train_net.init_ops() 206 | valid_net.init_ops() 207 | 208 | print("Network (layers={})".format(len(train_net.activations))) 209 | print("inputs ({}): {}".format( 210 | train_net.input_var.name, train_net.input_var.get_shape() 211 | )) 212 | print("targets ({}): {}".format( 213 | train_net.target_var.name, train_net.target_var.get_shape() 214 | )) 215 | for name, act in train_net.activations: 216 | print("{} ({}): {}".format(name, act.name, act.get_shape())) 217 | print(" ") 218 | # Global step for resume training 219 | with tf.compat.v1.variable_scope(train_net.name) as scope: 220 | global_step = tf.Variable(0, name="global_step", trainable=False) 221 | 222 | # boundaries = [10, 20, 30] 223 | # values = [1e-3, 1e-4, 5e-5, 1e-5] 224 | boundaries = [40] 225 | values = [1e-3, 1e-4] 226 | learning_rate = tf.train.piecewise_constant(global_step, boundaries, values) 227 | # learning_rate = 1e-4 228 | # Define optimization operations 229 | train_op, grads_and_vars_op = adam( 230 | loss=train_net.loss_op, 231 | lr=learning_rate, 232 | train_vars=tf.compat.v1.trainable_variables() 233 | ) 234 | 235 | # Make subdirectory for pretraining 236 | output_dir = os.path.join(self.output_dir, "fold{}".format(self.fold_idx), train_net.name) 237 | if not os.path.exists(output_dir): 238 | os.makedirs(output_dir) 239 | 240 | # Global step for resume training 241 | # with tf.compat.v1.variable_scope(train_net.name) as scope: 242 | # global_step = tf.Variable(0, name="global_step", trainable=False) 243 | 244 | # Create a saver 245 | saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=0) 246 | 247 | # Initialize variables in the graph 248 | sess.run(tf.compat.v1.global_variables_initializer()) 249 | 250 | # Resume the training if applicable 251 | if resume: 252 | if os.path.exists(output_dir): 253 | if os.path.isfile(os.path.join(output_dir, "checkpoint")): 254 | # Restore the last checkpoint 255 | saver.restore(sess, tf.train.latest_checkpoint(output_dir)) 256 | print("Model restored") 257 | print("[{}] Resume pre-training ...\n".format(datetime.now())) 258 | else: 259 | print("[{}] Start pre-training ...\n".format(datetime.now())) 260 | else: 261 | print("[{}] Start pre-training ...\n".format(datetime.now())) 262 | 263 | # Load data 264 | if sess.run(global_step) < n_epochs: 265 | data_loader = NonSeqDataLoader( 266 | data_dir=self.data_dir, 267 | n_folds=self.n_folds, 268 | fold_idx=self.fold_idx 269 | ) 270 | 271 | x_train, y_train, x_valid, y_valid = data_loader.load_train_data() 272 | 273 | # Performance history  274 | all_train_loss = np.zeros(n_epochs) 275 | all_train_acc = np.zeros(n_epochs) 276 | all_train_f1 = np.zeros(n_epochs) 277 | all_valid_loss = np.zeros(n_epochs) 278 | all_valid_acc = np.zeros(n_epochs) 279 | all_valid_f1 = np.zeros(n_epochs) 280 | 281 | # Loop each epoch  282 | for epoch in range(sess.run(global_step), n_epochs): 283 | print(sess.run(learning_rate),end='') 284 | # Update parameters and compute loss of training set 285 | y_true_train, y_pred_train, train_loss, train_duration = \ 286 | self._run_epoch( 287 | sess=sess, network=train_net, 288 | inputs=x_train, targets=y_train, 289 | train_op=train_op, 290 | is_train=True 291 | ) 292 | n_train_examples = len(y_true_train) 293 | train_cm = confusion_matrix(y_true_train, y_pred_train) 294 | train_acc = np.mean(y_true_train == y_pred_train) 295 | train_f1 = f1_score(y_true_train, y_pred_train, average="macro") 296 | 297 | # Evaluate the model on the validation set 298 | y_true_val, y_pred_val, valid_loss, valid_duration = \ 299 | self._run_epoch( 300 | sess=sess, network=valid_net, 301 | inputs=x_valid, targets=y_valid, 302 | train_op=tf.no_op(), 303 | is_train=False 304 | ) 305 | n_valid_examples = len(y_true_val) 306 | valid_cm = confusion_matrix(y_true_val, y_pred_val) 307 | valid_acc = np.mean(y_true_val == y_pred_val) 308 | valid_f1 = f1_score(y_true_val, y_pred_val, average="macro") 309 | 310 | all_train_loss[epoch] = train_loss 311 | all_train_acc[epoch] = train_acc 312 | all_train_f1[epoch] = train_f1 313 | all_valid_loss[epoch] = valid_loss 314 | all_valid_acc[epoch] = valid_acc 315 | all_valid_f1[epoch] = valid_f1 316 | 317 | # Report performance 318 | self.print_performance( 319 | sess, output_dir, train_net.name, 320 | n_train_examples, n_valid_examples, 321 | train_cm, valid_cm, epoch, n_epochs, 322 | train_duration, train_loss, train_acc, train_f1, 323 | valid_duration, valid_loss, valid_acc, valid_f1 324 | ) 325 | 326 | # Save performance history 327 | np.savez( 328 | os.path.join(output_dir, "perf_fold{}.npz".format(self.fold_idx)), 329 | train_loss=all_train_loss, valid_loss=all_valid_loss, 330 | train_acc=all_train_acc, valid_acc=all_valid_acc, 331 | train_f1=all_train_f1, valid_f1=all_valid_f1, 332 | y_true_val=np.asarray(y_true_val), 333 | y_pred_val=np.asarray(y_pred_val) 334 | ) 335 | # Save checkpoint 336 | 337 | # print(sess.run(learning_rate)) 338 | 339 | sess.run(tf.compat.v1.assign(global_step, epoch + 1)) 340 | if ((epoch + 1) % self.interval_save_model == 0) or ((epoch + 1) == n_epochs): 341 | start_time = time.time() 342 | save_path = os.path.join( 343 | output_dir, "model_fold{}.ckpt".format(self.fold_idx) 344 | ) 345 | # saver.save(sess, save_path, global_step=global_step) 346 | saver.save(sess, save_path) 347 | duration = time.time() - start_time 348 | print("Saved model checkpoint ({:.3f} sec)".format(duration)) 349 | 350 | # Save paramaters 351 | if ((epoch + 1) % self.interval_save_model == 0) or ((epoch + 1) == n_epochs): 352 | start_time = time.time() 353 | save_dict = {} 354 | for v in tf.compat.v1.global_variables(): 355 | save_dict[v.name] = sess.run(v) 356 | np.savez( 357 | os.path.join( 358 | output_dir, 359 | "params_fold{}.npz".format(self.fold_idx)), 360 | **save_dict 361 | ) 362 | duration = time.time() - start_time 363 | print("Saved trained parameters ({:.3f} sec)".format(duration)) 364 | 365 | print("Finish pre-training") 366 | return os.path.join(output_dir, "params_fold{}.npz".format(self.fold_idx)) 367 | 368 | 369 | class CCRRSleepNetTrainer(Trainer): 370 | 371 | def __init__( 372 | self, 373 | data_dir, 374 | output_dir, 375 | n_folds, 376 | fold_idx, 377 | batch_size, 378 | input_dims, 379 | n_classes, 380 | seq_length, 381 | n_rnn_layers, 382 | return_last, 383 | interval_plot_filter=50, 384 | interval_save_model=100, 385 | interval_print_cm=10 386 | ): 387 | super(self.__class__, self).__init__( 388 | interval_plot_filter=interval_plot_filter, 389 | interval_save_model=interval_save_model, 390 | interval_print_cm=interval_print_cm 391 | ) 392 | 393 | self.data_dir = data_dir 394 | self.output_dir = output_dir 395 | self.n_folds = n_folds 396 | self.fold_idx = fold_idx 397 | self.batch_size = batch_size 398 | self.input_dims = input_dims 399 | self.n_classes = n_classes 400 | self.seq_length = seq_length 401 | self.n_rnn_layers = n_rnn_layers 402 | self.return_last = return_last 403 | 404 | def _run_epoch(self, sess, network, inputs, targets, train_op, is_train): 405 | start_time = time.time() 406 | y = [] 407 | y_true = [] 408 | total_loss, n_batches = 0.0, 0 409 | for sub_idx, each_data in enumerate(zip(inputs, targets)): 410 | each_x, each_y = each_data 411 | 412 | # # Initialize state of LSTM - Unidirectional LSTM 413 | # state = sess.run(network.initial_state) 414 | 415 | # Initialize state of LSTM - Bidirectional LSTM 416 | fw_state = sess.run(network.fw_initial_state) 417 | bw_state = sess.run(network.bw_initial_state) 418 | 419 | for x_batch, y_batch in iterate_batch_seq_minibatches(inputs=each_x, 420 | targets=each_y, 421 | batch_size=self.batch_size, 422 | seq_length=self.seq_length): 423 | feed_dict = { 424 | network.input_var: x_batch, 425 | network.target_var: y_batch 426 | } 427 | 428 | for i, (c, h) in enumerate(network.fw_initial_state): 429 | feed_dict[c] = fw_state[i].c 430 | feed_dict[h] = fw_state[i].h 431 | 432 | for i, (c, h) in enumerate(network.bw_initial_state): 433 | feed_dict[c] = bw_state[i].c 434 | feed_dict[h] = bw_state[i].h 435 | 436 | _, loss_value, y_pred, fw_state, bw_state = sess.run( 437 | [train_op, network.loss_op, network.pred_op, network.fw_final_state, network.bw_final_state], 438 | feed_dict=feed_dict 439 | ) 440 | 441 | total_loss += loss_value 442 | n_batches += 1 443 | y.append(y_pred) 444 | y_true.append(y_batch) 445 | 446 | duration = time.time() - start_time 447 | total_loss /= n_batches 448 | total_y_pred = np.hstack(y) 449 | total_y_true = np.hstack(y_true) 450 | 451 | return total_y_true, total_y_pred, total_loss, duration 452 | 453 | def finetune(self, pretrained_model_path, n_epochs, resume): 454 | pretrained_model_name = "ccrrfeaturenet" 455 | 456 | with tf.Graph().as_default(), tf.compat.v1.Session() as sess: 457 | # Build training and validation networks 458 | train_net = CCRRSleepNet( 459 | batch_size=self.batch_size, 460 | input_dims=self.input_dims, 461 | n_classes=self.n_classes, 462 | seq_length=self.seq_length, 463 | n_rnn_layers=self.n_rnn_layers, 464 | return_last=self.return_last, 465 | is_train=True, 466 | reuse_params=False, 467 | use_dropout_feature=True, 468 | use_dropout_sequence=True 469 | ) 470 | valid_net = CCRRSleepNet( 471 | batch_size=self.batch_size, 472 | input_dims=self.input_dims, 473 | n_classes=self.n_classes, 474 | seq_length=self.seq_length, 475 | n_rnn_layers=self.n_rnn_layers, 476 | return_last=self.return_last, 477 | is_train=False, 478 | reuse_params=True, 479 | use_dropout_feature=True, 480 | use_dropout_sequence=True 481 | ) 482 | 483 | # Initialize parameters 484 | train_net.init_ops() 485 | valid_net.init_ops() 486 | 487 | print("Network (layers={})".format(len(train_net.activations))) 488 | print("inputs ({}): {}".format( 489 | train_net.input_var.name, train_net.input_var.get_shape() 490 | )) 491 | print("targets ({}): {}".format( 492 | train_net.target_var.name, train_net.target_var.get_shape() 493 | )) 494 | for name, act in train_net.activations: 495 | print("{} ({}): {}".format(name, act.name, act.get_shape())) 496 | print(" ") 497 | 498 | # Get list of all pretrained parameters 499 | with np.load(pretrained_model_path) as f: 500 | pretrain_params = list(f.keys()) 501 | 502 | # Remove the network-name-prefix 503 | for i in range(len(pretrain_params)): 504 | pretrain_params[i] = pretrain_params[i].replace(pretrained_model_name, "network") 505 | 506 | # Get trainable variables of the pretrained, and new ones 507 | train_vars1 = [v for v in tf.compat.v1.trainable_variables() 508 | if v.name.replace(train_net.name, "network") in pretrain_params] 509 | train_vars2 = list(set(tf.compat.v1.trainable_variables()) - set(train_vars1)) 510 | with tf.compat.v1.variable_scope(train_net.name) as scope: 511 | global_step = tf.Variable(0, name="global_step", trainable=False) 512 | 513 | boundaries = [10, 20, 30] 514 | values = [1e-3, 1e-4, 5e-5, 1e-5] 515 | learning_rate = tf.train.piecewise_constant(global_step, boundaries, values) 516 | # learning_rate = 1e-4 517 | # Optimizer that use different learning rates for each part of the network 518 | train_op, grads_and_vars_op = adam_clipping_list_lr( 519 | loss=train_net.loss_op, 520 | list_lrs=[1e-6, learning_rate], 521 | list_train_vars=[train_vars1, train_vars2], 522 | clip_value=10.0 523 | ) 524 | 525 | # Make subdirectory for pretraining 526 | output_dir = os.path.join(self.output_dir, "fold{}".format(self.fold_idx), train_net.name) 527 | if not os.path.exists(output_dir): 528 | os.makedirs(output_dir) 529 | 530 | # Global step for resume training 531 | 532 | 533 | 534 | # Create a saver 535 | saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=0) 536 | 537 | # Initialize variables in the graph 538 | sess.run(tf.compat.v1.global_variables_initializer()) 539 | 540 | # Resume the training if applicable 541 | load_pretrain = False 542 | if resume: 543 | if os.path.exists(output_dir): 544 | if os.path.isfile(os.path.join(output_dir, "checkpoint")): 545 | # Restore the last checkpoint 546 | saver.restore(sess, tf.train.latest_checkpoint(output_dir)) 547 | print("Model restored") 548 | print("[{}] Resume fine-tuning ...\n".format(datetime.now())) 549 | else: 550 | load_pretrain = True 551 | else: 552 | load_pretrain = True 553 | 554 | if load_pretrain: 555 | # Load pre-trained model 556 | print("Loading pre-trained parameters to the model ...") 557 | print(" | --> {} from {}".format(pretrained_model_name, pretrained_model_path)) 558 | with np.load(pretrained_model_path) as f: 559 | for k, v in f.items(): 560 | if "Adam" in k or "softmax" in k or "power" in k or "global_step" in k: 561 | continue 562 | prev_k = k 563 | k = k.replace(pretrained_model_name, train_net.name) 564 | tmp_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name(k) 565 | sess.run( 566 | tf.compat.v1.assign( 567 | tmp_tensor, 568 | v 569 | ) 570 | ) 571 | print("assigned {}: {} to {}: {}".format( 572 | prev_k, v.shape, k, tmp_tensor.get_shape() 573 | )) 574 | print(" ") 575 | print("[{}] Start fine-tuning ...\n".format(datetime.now())) 576 | 577 | # Load data 578 | if sess.run(global_step) < n_epochs: 579 | data_loader = SeqDataLoader( 580 | data_dir=self.data_dir, 581 | n_folds=self.n_folds, 582 | fold_idx=self.fold_idx 583 | ) 584 | x_train, y_train, x_valid, y_valid = data_loader.load_train_data() 585 | 586 | # Performance history 587 | all_train_loss = np.zeros(n_epochs) 588 | all_train_acc = np.zeros(n_epochs) 589 | all_train_f1 = np.zeros(n_epochs) 590 | all_valid_loss = np.zeros(n_epochs) 591 | all_valid_acc = np.zeros(n_epochs) 592 | all_valid_f1 = np.zeros(n_epochs) 593 | 594 | 595 | # Loop each epoch 596 | for epoch in range(sess.run(global_step), n_epochs): 597 | print(sess.run(learning_rate),end='') 598 | # Update parameters and compute loss of training set 599 | y_true_train, y_pred_train, train_loss, train_duration = \ 600 | self._run_epoch( 601 | sess=sess, network=train_net, 602 | inputs=x_train, targets=y_train, 603 | train_op=train_op, 604 | is_train=True 605 | ) 606 | n_train_examples = len(y_true_train) 607 | train_cm = confusion_matrix(y_true_train, y_pred_train) 608 | train_acc = np.mean(y_true_train == y_pred_train) 609 | train_f1 = f1_score(y_true_train, y_pred_train, average="macro") 610 | 611 | # Evaluate the model on the validation set 612 | y_true_val, y_pred_val, valid_loss, valid_duration = \ 613 | self._run_epoch( 614 | sess=sess, network=valid_net, 615 | inputs=x_valid, targets=y_valid, 616 | train_op=tf.no_op(), 617 | is_train=False 618 | ) 619 | n_valid_examples = len(y_true_val) 620 | valid_cm = confusion_matrix(y_true_val, y_pred_val) 621 | valid_acc = np.mean(y_true_val == y_pred_val) 622 | valid_f1 = f1_score(y_true_val, y_pred_val, average="macro") 623 | 624 | all_train_loss[epoch] = train_loss 625 | all_train_acc[epoch] = train_acc 626 | all_train_f1[epoch] = train_f1 627 | all_valid_loss[epoch] = valid_loss 628 | all_valid_acc[epoch] = valid_acc 629 | all_valid_f1[epoch] = valid_f1 630 | 631 | # Report performance 632 | self.print_performance( 633 | sess, output_dir, train_net.name, 634 | n_train_examples, n_valid_examples, 635 | train_cm, valid_cm, epoch, n_epochs, 636 | train_duration, train_loss, train_acc, train_f1, 637 | valid_duration, valid_loss, valid_acc, valid_f1 638 | ) 639 | 640 | # Save performance history 641 | np.savez( 642 | os.path.join(output_dir, "perf_fold{}.npz".format(self.fold_idx)), 643 | train_loss=all_train_loss, valid_loss=all_valid_loss, 644 | train_acc=all_train_acc, valid_acc=all_valid_acc, 645 | train_f1=all_train_f1, valid_f1=all_valid_f1, 646 | y_true_val=np.asarray(y_true_val), 647 | y_pred_val=np.asarray(y_pred_val) 648 | ) 649 | 650 | # Visualize weights from convolutional layers 651 | if ((epoch + 1) % self.interval_plot_filter == 0) or ((epoch + 1) == n_epochs): 652 | self.plot_filters(sess, epoch, train_net.name + "(_[0-9])?\/l[0-9]+_conv\/(weights)", output_dir, 653 | 16) 654 | self.plot_filters(sess, epoch, train_net.name + "(_[0-9])?/l[0-9]+_conv\/conv1d\/(weights)", 655 | output_dir, 16) 656 | 657 | # Save checkpoint 658 | sess.run(tf.compat.v1.assign(global_step, epoch + 1)) 659 | sess.run(tf.compat.v1.assign(global_step, epoch + 1)) 660 | if ((epoch + 1) % self.interval_save_model == 0) or ((epoch + 1) == n_epochs): 661 | start_time = time.time() 662 | save_path = os.path.join( 663 | output_dir, "model_fold{}.ckpt".format(self.fold_idx) 664 | ) 665 | # saver.save(sess, save_path, global_step=global_step) 666 | saver.save(sess, save_path) 667 | duration = time.time() - start_time 668 | print("Saved model checkpoint ({:.3f} sec)".format(duration)) 669 | 670 | # Save paramaters 671 | if ((epoch + 1) % self.interval_save_model == 0) or ((epoch + 1) == n_epochs): 672 | start_time = time.time() 673 | save_dict = {} 674 | for v in tf.compat.v1.global_variables(): 675 | save_dict[v.name] = sess.run(v) 676 | np.savez( 677 | os.path.join( 678 | output_dir, 679 | "params_fold{}.npz".format(self.fold_idx)), 680 | **save_dict 681 | ) 682 | duration = time.time() - start_time 683 | print("Saved trained parameters ({:.3f} sec)".format(duration)) 684 | 685 | print("Finish fine-tuning") 686 | return os.path.join(output_dir, "params_fold{}.npz".format(self.fold_idx)) 687 | 688 | -------------------------------------------------------------------------------- /ccrrsleep/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from collections import Counter 4 | # from scipy.ndimage.interpolation import shift 5 | from scipy.ndimage import shift 6 | 7 | 8 | 9 | def get_balance_class_downsample(x, y): 10 | """ 11 | Balance the number of samples of all classes by (downsampling): 12 | 1. Find the class that has a smallest number of samples 13 | 2. Randomly select samples in each class equal to that smallest number 14 | """ 15 | 16 | class_labels = np.unique(y) 17 | n_min_classes = -1 18 | for c in class_labels: 19 | n_samples = len(np.where(y == c)[0]) 20 | if n_min_classes == -1: 21 | n_min_classes = n_samples 22 | elif n_min_classes > n_samples: 23 | n_min_classes = n_samples 24 | 25 | balance_x = [] 26 | balance_y = [] 27 | for c in class_labels: 28 | idx = np.where(y == c)[0] 29 | idx = np.random.permutation(idx)[:n_min_classes] 30 | balance_x.append(x[idx]) 31 | balance_y.append(y[idx]) 32 | balance_x = np.vstack(balance_x) 33 | balance_y = np.hstack(balance_y) 34 | 35 | return balance_x, balance_y 36 | 37 | 38 | def get_balance_class_oversample(x, y): 39 | """ 40 | Balance the number of samples of all classes by (oversampling): 41 | 1. Find the class that has the largest number of samples 42 | 2. Randomly select samples in each class equal to that largest number 43 | """ 44 | 45 | from imblearn.over_sampling import SMOTE, BorderlineSMOTE,RandomOverSampler 46 | from imblearn.under_sampling import RandomUnderSampler 47 | # sm = SMOTE(random_state=1, n_jobs=32) 48 | sm = BorderlineSMOTE(random_state=1, kind="borderline-1", n_jobs=32) 49 | # sm = RandomOverSampler(random_state=1) 50 | # sm = RandomUnderSampler(random_state=1) 51 | 52 | x = np.squeeze(x) 53 | balance_x, balance_y = sm.fit_sample(x, y) 54 | balance_x = np.expand_dims(balance_x, axis=2) 55 | balance_x = np.expand_dims(balance_x, axis=3) 56 | return balance_x, balance_y 57 | 58 | ######## 59 | def iterate_minibatches(inputs, targets, batch_size, shuffle=False): 60 | """ 61 | Generate a generator that return a batch of inputs and targets. 62 | """ 63 | assert len(inputs) == len(targets) 64 | if shuffle: 65 | indices = np.arange(len(inputs)) 66 | np.random.shuffle(indices) 67 | for start_idx in range(0, len(inputs) - batch_size + 1, batch_size): 68 | if shuffle: 69 | excerpt = indices[start_idx:start_idx + batch_size] 70 | else: 71 | excerpt = slice(start_idx, start_idx + batch_size) 72 | if shuffle: 73 | yield DataAugmentation(inputs[excerpt]), targets[excerpt] 74 | else: 75 | yield inputs[excerpt], targets[excerpt] 76 | 77 | 78 | def iterate_seq_minibatches(inputs, targets, batch_size, seq_length, stride): 79 | """ 80 | Generate a generator that return a batch of sequence inputs and targets. 81 | """ 82 | assert len(inputs) == len(targets) 83 | n_loads = (batch_size * stride) + (seq_length - stride) 84 | for start_idx in range(0, len(inputs) - n_loads + 1, (batch_size * stride)): 85 | seq_inputs = np.zeros((batch_size, seq_length) + inputs.shape[1:], 86 | dtype=inputs.dtype) 87 | seq_targets = np.zeros((batch_size, seq_length) + targets.shape[1:], 88 | dtype=targets.dtype) 89 | for b_idx in range(batch_size): 90 | start_seq_idx = start_idx + (b_idx * stride) 91 | end_seq_idx = start_seq_idx + seq_length 92 | seq_inputs[b_idx] = inputs[start_seq_idx:end_seq_idx] 93 | seq_targets[b_idx] = targets[start_seq_idx:end_seq_idx] 94 | flatten_inputs = seq_inputs.reshape((-1,) + inputs.shape[1:]) 95 | flatten_targets = seq_targets.reshape((-1,) + targets.shape[1:]) 96 | yield flatten_inputs, flatten_targets 97 | 98 | ############ 99 | def iterate_batch_seq_minibatches(inputs, targets, batch_size, seq_length): 100 | assert len(inputs) == len(targets) 101 | n_inputs = len(inputs) 102 | batch_len = n_inputs // batch_size 103 | 104 | epoch_size = batch_len // seq_length 105 | if epoch_size == 0: 106 | raise ValueError("epoch_size == 0, decrease batch_size or seq_length") 107 | 108 | seq_inputs = np.zeros((batch_size, batch_len) + inputs.shape[1:], 109 | dtype=inputs.dtype) 110 | seq_targets = np.zeros((batch_size, batch_len) + targets.shape[1:], 111 | dtype=targets.dtype) 112 | 113 | for i in range(batch_size): 114 | seq_inputs[i] = inputs[i*batch_len:(i+1)*batch_len] 115 | seq_targets[i] = targets[i*batch_len:(i+1)*batch_len] 116 | 117 | for i in range(epoch_size): 118 | x = seq_inputs[:, i*seq_length:(i+1)*seq_length] 119 | y = seq_targets[:, i*seq_length:(i+1)*seq_length] 120 | flatten_x = x.reshape((-1,) + inputs.shape[1:]) 121 | flatten_y = y.reshape((-1,) + targets.shape[1:]) 122 | yield flatten_x, flatten_y 123 | 124 | 125 | def iterate_list_batch_seq_minibatches(inputs, targets, batch_size, seq_length): 126 | for idx, each_data in enumerate(zip(inputs, targets)): 127 | each_x, each_y = each_data 128 | seq_x, seq_y = [], [] 129 | for x_batch, y_batch in iterate_seq_minibatches(inputs=each_x, 130 | targets=each_y, 131 | batch_size=1, 132 | seq_length=seq_length, 133 | stride=1): 134 | seq_x.append(x_batch) 135 | seq_y.append(y_batch) 136 | seq_x = np.vstack(seq_x) 137 | seq_x = seq_x.reshape((-1, seq_length) + seq_x.shape[1:]) 138 | seq_y = np.hstack(seq_y) 139 | seq_y = seq_y.reshape((-1, seq_length) + seq_y.shape[1:]) 140 | 141 | for x_batch, y_batch in iterate_batch_seq_minibatches(inputs=seq_x, 142 | targets=seq_y, 143 | batch_size=batch_size, 144 | seq_length=1): 145 | x_batch = x_batch.reshape((-1,) + x_batch.shape[2:]) 146 | y_batch = y_batch.reshape((-1,) + y_batch.shape[2:]) 147 | yield x_batch, y_batch 148 | 149 | 150 | 151 | def DataAugmentation(x, roll_range=0.5, horizontal_flip=True,seed=None): 152 | assert x.shape[1:] == (3000, 1, 1) 153 | if seed is not None: 154 | np.random.seed(seed) 155 | N = x.shape 156 | 157 | if roll_range: 158 | if np.random.random() < 0.5: 159 | tx = np.random.uniform(-roll_range, roll_range) 160 | if roll_range < 1: 161 | tx *= N[1] 162 | x = np.roll(x, int(tx), axis=1) 163 | 164 | if horizontal_flip: 165 | if np.random.random() < 0.5: 166 | x = np.flip(x,axis=1) 167 | return x 168 | 169 | 170 | def sequence_down_sample(data,lable,sequence_length=25, down_rate=5): 171 | lab_dict = Counter(lable) 172 | lab_list = [] 173 | for i in lab_dict: 174 | lab_i = np.array([i, lab_dict[i]]) 175 | lab_list.append(lab_i) 176 | lab_list = np.array(lab_list) 177 | lab_array = lab_list[np.argsort(lab_list[:, 0])] 178 | 179 | data_list = [] 180 | lab_list = [] 181 | # 10 最小控制数 batch_size*down_rate 182 | 183 | min_num = int(lab_array[:, 1].min()/down_rate) 184 | min_num = min_num if min_num>1 else int(lab_array[:, 1].min()) 185 | 186 | for i in range(len(lab_array)): 187 | lab = lab_array[i, 0] 188 | data_idx = np.array(np.where(lable == lab)).reshape(-1) 189 | np.random.shuffle(data_idx) 190 | data_idx = data_idx[0:int(min_num)] 191 | 192 | for j in range(int(min_num)): 193 | N = sequence_length 194 | bias = np.random.randint(int(-N / 2), int(N / 2)) 195 | if data_idx[j] + bias + N > data.shape[0]: 196 | bias = -N 197 | if data_idx[j] + bias < 0: 198 | bias = N 199 | data_ij = data[data_idx[j] + bias:data_idx[j] + bias + N, :] 200 | lab_ij = lable[data_idx[j] + bias:data_idx[j] + bias + N] 201 | data_list.append(data_ij) 202 | lab_list.append(lab_ij) 203 | data_sample = np.concatenate(tuple([data_i for data_i in data_list]), axis=0) 204 | lab_sample = np.concatenate(tuple([lab_i for lab_i in lab_list]), axis=0) 205 | return data_sample,lab_sample -------------------------------------------------------------------------------- /dhedfreader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reader for EDF+ files. 3 | TODO: 4 | - add support for log-transformed channels: 5 | http://www.edfplus.info/specs/edffloat.html and test with 6 | data generated with 7 | http://www.edfplus.info/downloads/software/NeuroLoopGain.zip. 8 | - check annotations with Schalk's Physiobank data. 9 | Copyright (c) 2012 Boris Reuderink. 10 | ''' 11 | 12 | import re, datetime, operator, logging 13 | import numpy as np 14 | from collections import namedtuple 15 | from functools import reduce 16 | 17 | EVENT_CHANNEL = 'EDF Annotations' 18 | log = logging.getLogger(__name__) 19 | 20 | class EDFEndOfData(BaseException): pass 21 | 22 | 23 | def tal(tal_str): 24 | '''Return a list with (onset, duration, annotation) tuples for an EDF+ TAL 25 | stream. 26 | ''' 27 | exp = '(?P[+\-]\d+(?:\.\d*)?)' + \ 28 | '(?:\x15(?P\d+(?:\.\d*)?))?' + \ 29 | '(\x14(?P[^\x00]*))?' + \ 30 | '(?:\x14\x00)' 31 | 32 | def annotation_to_list(annotation): 33 | #return str(annotation, 'utf-8').split('\x14') if annotation else [] 34 | return annotation.split('\x14') if annotation else [] 35 | 36 | def parse(dic): 37 | return ( 38 | float(dic['onset']), 39 | float(dic['duration']) if dic['duration'] else 0., 40 | annotation_to_list(dic['annotation'])) 41 | 42 | return [parse(m.groupdict()) for m in re.finditer(exp, tal_str)] 43 | 44 | 45 | def edf_header(f): 46 | h = {} 47 | assert f.tell() == 0 # check file position 48 | assert f.read(8) == '0 ' 49 | 50 | # recording info) 51 | h['local_subject_id'] = f.read(80).strip() 52 | h['local_recording_id'] = f.read(80).strip() 53 | 54 | # parse timestamp 55 | (day, month, year) = [int(x) for x in re.findall('(\d+)', f.read(8))] 56 | (hour, minute, sec)= [int(x) for x in re.findall('(\d+)', f.read(8))] 57 | h['date_time'] = str(datetime.datetime(year + 2000, month, day, 58 | hour, minute, sec)) 59 | 60 | # misc 61 | header_nbytes = int(f.read(8)) 62 | subtype = f.read(44)[:5] 63 | h['EDF+'] = subtype in ['EDF+C', 'EDF+D'] 64 | h['contiguous'] = subtype != 'EDF+D' 65 | h['n_records'] = int(f.read(8)) 66 | h['record_length'] = float(f.read(8)) # in seconds 67 | nchannels = h['n_channels'] = int(f.read(4)) 68 | 69 | # read channel info 70 | channels = list(range(h['n_channels'])) 71 | h['label'] = [f.read(16).strip() for n in channels] 72 | h['transducer_type'] = [f.read(80).strip() for n in channels] 73 | h['units'] = [f.read(8).strip() for n in channels] 74 | h['physical_min'] = np.asarray([float(f.read(8)) for n in channels]) 75 | h['physical_max'] = np.asarray([float(f.read(8)) for n in channels]) 76 | h['digital_min'] = np.asarray([float(f.read(8)) for n in channels]) 77 | h['digital_max'] = np.asarray([float(f.read(8)) for n in channels]) 78 | h['prefiltering'] = [f.read(80).strip() for n in channels] 79 | h['n_samples_per_record'] = [int(f.read(8)) for n in channels] 80 | f.read(32 * nchannels) # reserved 81 | 82 | assert f.tell() == header_nbytes 83 | return h 84 | 85 | 86 | class BaseEDFReader: 87 | def __init__(self, file): 88 | self.file = file 89 | 90 | 91 | def read_header(self): 92 | self.header = h = edf_header(self.file) 93 | 94 | # calculate ranges for rescaling 95 | self.dig_min = h['digital_min'] 96 | self.phys_min = h['physical_min'] 97 | phys_range = h['physical_max'] - h['physical_min'] 98 | dig_range = h['digital_max'] - h['digital_min'] 99 | assert np.all(phys_range > 0) 100 | assert np.all(dig_range > 0) 101 | self.gain = phys_range / dig_range 102 | 103 | 104 | def read_raw_record(self): 105 | '''Read a record with data and return a list containing arrays with raw 106 | bytes. 107 | ''' 108 | result = [] 109 | for nsamp in self.header['n_samples_per_record']: 110 | samples = self.file.read(nsamp * 2) 111 | if len(samples) != nsamp * 2: 112 | raise EDFEndOfData 113 | result.append(samples) 114 | return result 115 | 116 | 117 | def convert_record(self, raw_record): 118 | '''Convert a raw record to a (time, signals, events) tuple based on 119 | information in the header. 120 | ''' 121 | h = self.header 122 | dig_min, phys_min, gain = self.dig_min, self.phys_min, self.gain 123 | time = float('nan') 124 | signals = [] 125 | events = [] 126 | for (i, samples) in enumerate(raw_record): 127 | if h['label'][i] == EVENT_CHANNEL: 128 | ann = tal(samples) 129 | time = ann[0][0] 130 | events.extend(ann[1:]) 131 | # print(i, samples) 132 | # exit() 133 | else: 134 | # 2-byte little-endian integers 135 | dig = np.fromstring(samples, '= sequence_length(b)) 133 | ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 134 | : cell(input(b, t), state(b, t - 1)) 135 | ``` 136 | Args: 137 | cell: An instance of RNNCell. 138 | inputs: A length T list of inputs, each a `Tensor` of shape 139 | `[batch_size, input_size]`, or a nested tuple of such elements. 140 | initial_state: (optional) An initial state for the RNN. 141 | If `cell.state_size` is an integer, this must be 142 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 143 | If `cell.state_size` is a tuple, this should be a tuple of 144 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 145 | dtype: (optional) The data type for the initial state and expected output. 146 | Required if initial_state is not provided or RNN state has a heterogeneous 147 | dtype. 148 | sequence_length: Specifies the length of each sequence in inputs. 149 | An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 150 | scope: VariableScope for the created subgraph; defaults to "RNN". 151 | Returns: 152 | A pair (outputs, state) where: 153 | - outputs is a length T list of outputs (one for each input), or a nested 154 | tuple of such elements. 155 | - state is the final state 156 | Raises: 157 | TypeError: If `cell` is not an instance of RNNCell. 158 | ValueError: If `inputs` is `None` or an empty list, or if the input depth 159 | (column size) cannot be inferred from inputs via shape inference. 160 | """ 161 | 162 | if not isinstance(cell, tf.compat.v1.nn.rnn_cell.RNNCell): 163 | raise TypeError("cell must be an instance of RNNCell") 164 | if not nest.is_sequence(inputs): 165 | raise TypeError("inputs must be a sequence") 166 | if not inputs: 167 | raise ValueError("inputs must not be empty") 168 | 169 | outputs = [] 170 | states = [] 171 | # Create a new scope in which the caching device is either 172 | # determined by the parent scope, or is set to place the cached 173 | # Variable using the same placement as for the rest of the RNN. 174 | with vs.variable_scope(scope or "RNN") as varscope: 175 | if varscope.caching_device is None: 176 | varscope.set_caching_device(lambda op: op.device) 177 | 178 | # Obtain the first sequence of the input 179 | first_input = inputs 180 | while nest.is_sequence(first_input): 181 | first_input = first_input[0] 182 | 183 | # Temporarily avoid EmbeddingWrapper and seq2seq badness 184 | # TODO(lukaszkaiser): remove EmbeddingWrapper 185 | if first_input.get_shape().ndims != 1: 186 | 187 | input_shape = first_input.get_shape().with_rank_at_least(2) 188 | fixed_batch_size = input_shape[0] 189 | 190 | flat_inputs = nest.flatten(inputs) 191 | for flat_input in flat_inputs: 192 | input_shape = flat_input.get_shape().with_rank_at_least(2) 193 | batch_size, input_size = input_shape[0], input_shape[1:] 194 | fixed_batch_size.merge_with(batch_size) 195 | for i, size in enumerate(input_size): 196 | if size.value is None: 197 | raise ValueError( 198 | "Input size (dimension %d of inputs) must be accessible via " 199 | "shape inference, but saw value None." % i) 200 | else: 201 | fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 202 | 203 | if fixed_batch_size.value: 204 | batch_size = fixed_batch_size.value 205 | else: 206 | batch_size = array_ops.shape(first_input)[0] 207 | if initial_state is not None: 208 | state = initial_state 209 | else: 210 | if not dtype: 211 | raise ValueError("If no initial_state is provided, " 212 | "dtype must be specified") 213 | state = cell.zero_state(batch_size, dtype) 214 | 215 | if sequence_length is not None: # Prepare variables 216 | sequence_length = ops.convert_to_tensor( 217 | sequence_length, name="sequence_length") 218 | if sequence_length.get_shape().ndims not in (None, 1): 219 | raise ValueError( 220 | "sequence_length must be a vector of length batch_size") 221 | def _create_zero_output(output_size): 222 | # convert int to TensorShape if necessary 223 | size = _state_size_with_prefix(output_size, prefix=[batch_size]) 224 | output = array_ops.zeros( 225 | array_ops.pack(size), _infer_state_dtype(dtype, state)) 226 | shape = _state_size_with_prefix( 227 | output_size, prefix=[fixed_batch_size.value]) 228 | output.set_shape(tensor_shape.TensorShape(shape)) 229 | return output 230 | 231 | output_size = cell.output_size 232 | flat_output_size = nest.flatten(output_size) 233 | flat_zero_output = tuple( 234 | _create_zero_output(size) for size in flat_output_size) 235 | zero_output = nest.pack_sequence_as(structure=output_size, 236 | flat_sequence=flat_zero_output) 237 | 238 | sequence_length = math_ops.to_int32(sequence_length) 239 | min_sequence_length = math_ops.reduce_min(sequence_length) 240 | max_sequence_length = math_ops.reduce_max(sequence_length) 241 | 242 | for time, input_ in enumerate(inputs): 243 | if time > 0: varscope.reuse_variables() 244 | # pylint: disable=cell-var-from-loop 245 | call_cell = lambda: cell(input_, state) 246 | # pylint: enable=cell-var-from-loop 247 | if sequence_length is not None: 248 | (output, state) = _rnn_step( 249 | time=time, 250 | sequence_length=sequence_length, 251 | min_sequence_length=min_sequence_length, 252 | max_sequence_length=max_sequence_length, 253 | zero_output=zero_output, 254 | state=state, 255 | call_cell=call_cell, 256 | state_size=cell.state_size) 257 | else: 258 | (output, state) = call_cell() 259 | 260 | outputs.append(output) 261 | states.append(state) 262 | 263 | return (outputs, state, states) 264 | 265 | 266 | def custom_bidirectional_rnn(cell_fw, cell_bw, inputs, 267 | initial_state_fw=None, initial_state_bw=None, 268 | dtype=None, sequence_length=None, scope=None): 269 | """Creates a bidirectional recurrent neural network. 270 | Similar to the unidirectional case above (rnn) but takes input and builds 271 | independent forward and backward RNNs with the final forward and backward 272 | outputs depth-concatenated, such that the output will have the format 273 | [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 274 | forward and backward cell must match. The initial state for both directions 275 | is zero by default (but can be set optionally) and no intermediate states are 276 | ever returned -- the network is fully unrolled for the given (passed in) 277 | length(s) of the sequence(s) or completely unrolled if length(s) is not given. 278 | Args: 279 | cell_fw: An instance of RNNCell, to be used for forward direction. 280 | cell_bw: An instance of RNNCell, to be used for backward direction. 281 | inputs: A length T list of inputs, each a tensor of shape 282 | [batch_size, input_size], or a nested tuple of such elements. 283 | initial_state_fw: (optional) An initial state for the forward RNN. 284 | This must be a tensor of appropriate type and shape 285 | `[batch_size, cell_fw.state_size]`. 286 | If `cell_fw.state_size` is a tuple, this should be a tuple of 287 | tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 288 | initial_state_bw: (optional) Same as for `initial_state_fw`, but using 289 | the corresponding properties of `cell_bw`. 290 | dtype: (optional) The data type for the initial state. Required if 291 | either of the initial states are not provided. 292 | sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 293 | containing the actual lengths for each of the sequences. 294 | scope: VariableScope for the created subgraph; defaults to "BiRNN" 295 | Returns: 296 | A tuple (outputs, output_state_fw, output_state_bw) where: 297 | outputs is a length `T` list of outputs (one for each input), which 298 | are depth-concatenated forward and backward outputs. 299 | output_state_fw is the final state of the forward rnn. 300 | output_state_bw is the final state of the backward rnn. 301 | Raises: 302 | TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 303 | ValueError: If inputs is None or an empty list. 304 | """ 305 | 306 | if not isinstance(cell_fw, tf.compat.v1.nn.rnn_cell.RNNCell): 307 | raise TypeError("cell_fw must be an instance of RNNCell") 308 | if not isinstance(cell_bw, tf.compat.v1.nn.rnn_cell.RNNCell): 309 | raise TypeError("cell_bw must be an instance of RNNCell") 310 | if not nest.is_sequence(inputs): 311 | raise TypeError("inputs must be a sequence") 312 | if not inputs: 313 | raise ValueError("inputs must not be empty") 314 | 315 | with vs.variable_scope(scope or "bidirectional_rnn"): 316 | # Forward direction 317 | with vs.variable_scope("fw") as fw_scope: 318 | output_fw, output_state_fw, fw_states = custom_rnn( 319 | cell_fw, inputs, initial_state_fw, dtype, 320 | sequence_length, scope=fw_scope 321 | ) 322 | 323 | # Backward direction 324 | with vs.variable_scope("bw") as bw_scope: 325 | reversed_inputs = _reverse_seq(inputs, sequence_length) 326 | tmp, output_state_bw, tmp_states = custom_rnn( 327 | cell_bw, reversed_inputs, initial_state_bw, 328 | dtype, sequence_length, scope=bw_scope 329 | ) 330 | 331 | output_bw = _reverse_seq(tmp, sequence_length) 332 | bw_states = _reverse_seq(tmp_states, sequence_length) 333 | 334 | # Concat each of the forward/backward outputs 335 | flat_output_fw = nest.flatten(output_fw) 336 | flat_output_bw = nest.flatten(output_bw) 337 | 338 | flat_outputs = tuple(array_ops.concat(values=[fw, bw], axis=1) 339 | for fw, bw in zip(flat_output_fw, flat_output_bw)) 340 | 341 | outputs = nest.pack_sequence_as(structure=output_fw, 342 | flat_sequence=flat_outputs) 343 | 344 | return (outputs, output_state_fw, output_state_bw, fw_states, bw_states) 345 | 346 | 347 | class CustomCCRRSleepNet(CCRRSleepNet): 348 | 349 | def __init__( 350 | self, 351 | batch_size, 352 | input_dims, 353 | n_classes, 354 | seq_length, 355 | n_rnn_layers, 356 | return_last, 357 | is_train, 358 | reuse_params, 359 | use_dropout_feature, 360 | use_dropout_sequence, 361 | name="ccrrsleepnet" 362 | ): 363 | super(CCRRSleepNet, self).__init__( 364 | batch_size=batch_size, 365 | input_dims=input_dims, 366 | n_classes=n_classes, 367 | is_train=is_train, 368 | reuse_params=reuse_params, 369 | use_dropout=use_dropout_feature, 370 | name=name 371 | ) 372 | 373 | self.seq_length = seq_length 374 | self.n_rnn_layers = n_rnn_layers 375 | self.return_last = return_last 376 | 377 | self.use_dropout_sequence = use_dropout_sequence 378 | 379 | def build_model(self, input_var): 380 | # Create a network with superclass method 381 | network = super(CCRRSleepNet, self).build_model( 382 | input_var=self.input_var 383 | ) 384 | 385 | # Residual (or shortcut) connection 386 | output_conns = [] 387 | 388 | # Fully-connected to select some part of the output to add with the output from bi-directional LSTM 389 | name = "l{}_fc".format(self.layer_idx) 390 | with tf.compat.v1.variable_scope(name) as scope: 391 | output = fc(name="fc", input_var=network, n_hiddens=1024, bias=None, wd=0) 392 | output = batch_norm_new(name="bn", input_var=output, is_train=self.is_train) 393 | output = mish(output) 394 | self.activations.append((name, output)) 395 | self.layer_idx += 1 396 | output_conns.append(output) 397 | # output_conns.append(network) 398 | 399 | ###################################################################### 400 | 401 | # Reshape the input from (batch_size * seq_length, input_dim) to 402 | # (batch_size, seq_length, input_dim) 403 | name = "l{}_reshape_seq".format(self.layer_idx) 404 | input_dim = network.get_shape()[-1].value 405 | seq_input = tf.reshape(network, 406 | shape=[-1, self.seq_length, input_dim], 407 | name=name) 408 | assert self.batch_size == seq_input.get_shape()[0].value 409 | self.activations.append((name, seq_input)) 410 | self.layer_idx += 1 411 | 412 | # Bidirectional LSTM network 413 | name = "l{}_bi_lstm".format(self.layer_idx) 414 | hidden_size = 512 # will output 1024 (512 forward, 512 backward) 415 | with tf.compat.v1.variable_scope(name) as scope: 416 | 417 | def lstm_cell(): 418 | cell = tf.compat.v1.nn.rnn_cell.LSTMCell(hidden_size, 419 | use_peepholes=True, 420 | state_is_tuple=True, 421 | reuse=tf.compat.v1.get_variable_scope().reuse) 422 | if self.use_dropout_sequence: 423 | keep_prob = 0.8 if self.is_train else 1.0 424 | cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper( 425 | cell, 426 | output_keep_prob=keep_prob 427 | ) 428 | 429 | return cell 430 | 431 | fw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple = True) 432 | bw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple = True) 433 | 434 | # Initial state of RNN 435 | self.fw_initial_state = fw_cell.zero_state(self.batch_size, tf.float32) 436 | self.bw_initial_state = bw_cell.zero_state(self.batch_size, tf.float32) 437 | 438 | # Feedforward to MultiRNNCell 439 | list_rnn_inputs = tf.unstack(seq_input, axis=1) 440 | outputs, fw_state, bw_state, fw_states, bw_states = custom_bidirectional_rnn( 441 | cell_fw=fw_cell, 442 | cell_bw=bw_cell, 443 | inputs=list_rnn_inputs, 444 | initial_state_fw=self.fw_initial_state, 445 | initial_state_bw=self.bw_initial_state 446 | ) 447 | 448 | if self.return_last: 449 | network = outputs[-1] 450 | else: 451 | network = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size*2], 452 | name=name) 453 | self.activations.append((name, network)) 454 | self.layer_idx +=1 455 | 456 | self.fw_final_state = fw_state 457 | self.bw_final_state = bw_state 458 | 459 | self.fw_states = fw_states 460 | self.bw_states = bw_states 461 | 462 | # Append output 463 | output_conns.append(network) 464 | 465 | ###################################################################### 466 | 467 | # Add 468 | name = "l{}_add".format(self.layer_idx) 469 | network = tf.add_n(output_conns, name=name) 470 | self.activations.append((name, network)) 471 | self.layer_idx += 1 472 | 473 | # Dropout 474 | if self.use_dropout_sequence: 475 | name = "l{}_dropout".format(self.layer_idx) 476 | if self.is_train: 477 | network = tf.nn.dropout(network, keep_prob=0.8, name=name) 478 | else: 479 | network = tf.nn.dropout(network, keep_prob=1.0, name=name) 480 | self.activations.append((name, network)) 481 | self.layer_idx += 1 482 | 483 | return network 484 | 485 | 486 | def custom_run_epoch( 487 | sess, 488 | network, 489 | inputs, 490 | targets, 491 | train_op, 492 | is_train, 493 | output_dir, 494 | subject_idx 495 | ): 496 | start_time = time.time() 497 | y = [] 498 | y_true = [] 499 | all_fw_memory_cells = [] 500 | all_bw_memory_cells = [] 501 | total_loss, n_batches = 0.0, 0 502 | for sub_f_idx, each_data in enumerate(zip(inputs, targets)): 503 | each_x, each_y = each_data 504 | 505 | # # Initialize state of LSTM - Unidirectional LSTM 506 | # state = sess.run(network.initial_state) 507 | 508 | # Initialize state of LSTM - Bidirectional LSTM 509 | fw_state = sess.run(network.fw_initial_state) 510 | bw_state = sess.run(network.bw_initial_state) 511 | 512 | # Prepare storage for memory cells 513 | n_all_data = len(each_x) 514 | extra = n_all_data % network.seq_length 515 | n_data = n_all_data - extra 516 | cell_size = 512 517 | fw_memory_cells = np.zeros((n_data, network.n_rnn_layers, cell_size)) 518 | bw_memory_cells = np.zeros((n_data, network.n_rnn_layers, cell_size)) 519 | seq_idx = 0 520 | 521 | # Store prediction and actual stages of each patient 522 | each_y_true = [] 523 | each_y_pred = [] 524 | 525 | for x_batch, y_batch in iterate_batch_seq_minibatches(inputs=each_x, 526 | targets=each_y, 527 | batch_size=network.batch_size, 528 | seq_length=network.seq_length): 529 | feed_dict = { 530 | network.input_var: x_batch, 531 | network.target_var: y_batch 532 | } 533 | 534 | # Unidirectional LSTM 535 | # for i, (c, h) in enumerate(network.initial_state): 536 | # feed_dict[c] = state[i].c 537 | # feed_dict[h] = state[i].h 538 | 539 | # _, loss_value, y_pred, state = sess.run( 540 | # [train_op, network.loss_op, network.pred_op, network.final_state], 541 | # feed_dict=feed_dict 542 | # ) 543 | 544 | for i, (c, h) in enumerate(network.fw_initial_state): 545 | feed_dict[c] = fw_state[i].c 546 | feed_dict[h] = fw_state[i].h 547 | 548 | for i, (c, h) in enumerate(network.bw_initial_state): 549 | feed_dict[c] = bw_state[i].c 550 | feed_dict[h] = bw_state[i].h 551 | 552 | _, loss_value, y_pred, fw_state, bw_state = sess.run( 553 | [train_op, network.loss_op, network.pred_op, network.fw_final_state, network.bw_final_state], 554 | feed_dict=feed_dict 555 | ) 556 | 557 | # Extract memory cells 558 | fw_states = sess.run(network.fw_states, feed_dict=feed_dict) 559 | bw_states = sess.run(network.bw_states, feed_dict=feed_dict) 560 | offset_idx = seq_idx * network.seq_length 561 | for s_idx in range(network.seq_length): 562 | for r_idx in range(network.n_rnn_layers): 563 | fw_memory_cells[offset_idx + s_idx][r_idx] = np.squeeze(fw_states[s_idx][r_idx].c) 564 | bw_memory_cells[offset_idx + s_idx][r_idx] = np.squeeze(bw_states[s_idx][r_idx].c) 565 | seq_idx += 1 566 | each_y_true.extend(y_batch) 567 | each_y_pred.extend(y_pred) 568 | 569 | total_loss += loss_value 570 | n_batches += 1 571 | 572 | # Check the loss value 573 | # assert not np.isnan(loss_value), \ 574 | # "Model diverged with loss = NaN" 575 | 576 | all_fw_memory_cells.append(fw_memory_cells) 577 | all_bw_memory_cells.append(bw_memory_cells) 578 | y.append(each_y_pred) 579 | y_true.append(each_y_true) 580 | 581 | # Save memory cells and predictions 582 | save_dict = { 583 | "fw_memory_cells": fw_memory_cells, 584 | "bw_memory_cells": bw_memory_cells, 585 | "y_true": y_true, 586 | "y_pred": y 587 | } 588 | save_path = os.path.join( 589 | output_dir, 590 | "output_subject{}.npz".format(subject_idx) 591 | ) 592 | np.savez(save_path, **save_dict) 593 | print("Saved outputs to {}".format(save_path)) 594 | 595 | duration = time.time() - start_time 596 | total_loss /= n_batches 597 | total_y_pred = np.hstack(y) 598 | total_y_true = np.hstack(y_true) 599 | 600 | return total_y_true, total_y_pred, total_loss, duration 601 | 602 | 603 | def predict( 604 | data_dir, 605 | model_dir, 606 | output_dir, 607 | n_subjects, 608 | n_subjects_per_fold 609 | ): 610 | # Ground truth and predictions 611 | y_true = [] 612 | y_pred = [] 613 | 614 | # The model will be built into the default Graph 615 | with tf.Graph().as_default(), tf.compat.v1.Session() as sess: 616 | # Build the network 617 | valid_net = CustomCCRRSleepNet( 618 | batch_size=1, 619 | input_dims=EPOCH_SEC_LEN*100, 620 | n_classes=NUM_CLASSES, 621 | seq_length=25, 622 | n_rnn_layers=2, 623 | return_last=False, 624 | is_train=False, 625 | reuse_params=False, 626 | use_dropout_feature=True, 627 | use_dropout_sequence=True 628 | ) 629 | 630 | # Initialize parameters 631 | valid_net.init_ops() 632 | 633 | for subject_idx in range(n_subjects): 634 | fold_idx = subject_idx // n_subjects_per_fold 635 | 636 | checkpoint_path = os.path.join( 637 | model_dir, 638 | "fold{}".format(fold_idx), 639 | "ccrrsleepnet" 640 | ) 641 | 642 | # Restore the trained model 643 | saver = tf.compat.v1.train.Saver() 644 | saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path)) 645 | print("Model restored from: {}\n".format(tf.train.latest_checkpoint(checkpoint_path))) 646 | 647 | # Load testing data 648 | Seqdata = SeqDataLoader(data_dir, n_subjects, fold_idx) 649 | # SeqDataLoader = SeqDataLoader(data_dir,n_subjects,subject_idx) 650 | x, y = Seqdata.load_test_data() 651 | # x, y = SeqDataLoader.load_subject_data( 652 | # data_dir=data_dir, 653 | # subject_idx=subject_idx 654 | # ) 655 | 656 | # Loop each epoch 657 | print("[{}] Predicting ...\n".format(datetime.now())) 658 | 659 | # Evaluate the model on the subject data 660 | y_true_, y_pred_, loss, duration = \ 661 | custom_run_epoch( 662 | sess=sess, network=valid_net, 663 | inputs=x, targets=y, 664 | train_op=tf.no_op(), 665 | is_train=False, 666 | output_dir=output_dir, 667 | subject_idx=subject_idx 668 | ) 669 | n_examples = len(y_true_) 670 | cm_ = confusion_matrix(y_true_, y_pred_) 671 | acc_ = np.mean(y_true_ == y_pred_) 672 | mf1_ = f1_score(y_true_, y_pred_, average="macro") 673 | 674 | # Report performance 675 | print_performance( 676 | sess, valid_net.name, 677 | n_examples, duration, loss, 678 | cm_, acc_, mf1_ 679 | ) 680 | 681 | y_true.extend(y_true_) 682 | y_pred.extend(y_pred_) 683 | 684 | # Overall performance 685 | print("[{}] Overall prediction performance\n".format(datetime.now())) 686 | y_true = np.asarray(y_true) 687 | y_pred = np.asarray(y_pred) 688 | n_examples = len(y_true) 689 | cm = confusion_matrix(y_true, y_pred) 690 | acc = np.mean(y_true == y_pred) 691 | mf1 = f1_score(y_true, y_pred, average="macro") 692 | print(( 693 | "n={}, acc={:.3f}, f1={:.3f}".format( 694 | n_examples, acc, mf1 695 | ) 696 | )) 697 | print(cm) 698 | 699 | 700 | def main(argv=None): 701 | # # Makes the random numbers predictable 702 | # np.random.seed(0) 703 | # tf.set_random_seed(0) 704 | 705 | # Output dir 706 | if not os.path.exists(FLAGS.output_dir): 707 | os.makedirs(FLAGS.output_dir) 708 | 709 | n_subjects = 20 710 | n_subjects_per_fold = 1 711 | predict( 712 | data_dir=FLAGS.data_dir, 713 | model_dir=FLAGS.model_dir, 714 | output_dir=FLAGS.output_dir, 715 | n_subjects=n_subjects, 716 | n_subjects_per_fold=n_subjects_per_fold 717 | ) 718 | 719 | 720 | if __name__ == "__main__": 721 | tf.compat.v1.app.run() 722 | -------------------------------------------------------------------------------- /prepare_physionet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import ntpath 5 | import os 6 | import shutil 7 | import urllib.request, urllib.parse, urllib.error 8 | import urllib.request, urllib.error, urllib.parse 9 | 10 | from datetime import datetime 11 | 12 | import numpy as np 13 | 14 | from mne import Epochs, pick_types, find_events 15 | from mne.io import concatenate_raws, read_raw_edf 16 | 17 | import dhedfreader 18 | 19 | 20 | # Label values 21 | W = 0 22 | N1 = 1 23 | N2 = 2 24 | N3 = 3 25 | REM = 4 26 | UNKNOWN = 5 27 | 28 | stage_dict = { 29 | "W": W, 30 | "N1": N1, 31 | "N2": N2, 32 | "N3": N3, 33 | "REM": REM, 34 | "UNKNOWN": UNKNOWN 35 | } 36 | 37 | class_dict = { 38 | 0: "W", 39 | 1: "N1", 40 | 2: "N2", 41 | 3: "N3", 42 | 4: "REM", 43 | 5: "UNKNOWN" 44 | } 45 | 46 | ann2label = { 47 | "Sleep stage W": 0, 48 | "Sleep stage 1": 1, 49 | "Sleep stage 2": 2, 50 | "Sleep stage 3": 3, 51 | "Sleep stage 4": 3, 52 | "Sleep stage R": 4, 53 | "Sleep stage ?": 5, 54 | "Movement time": 5 55 | } 56 | 57 | EPOCH_SEC_SIZE = 30 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument("--data_dir", type=str, default="../sleep-edf-database/sleep-cassette", 63 | help="File path to the CSV or NPY file that contains walking data.") 64 | parser.add_argument("--output_dir", type=str, default="../data/data-153/eeg_pz_oz", 65 | help="Directory where to save outputs.") 66 | # parser.add_argument("--select_ch", type=str, default="EEG Fpz-Cz", 67 | parser.add_argument("--select_ch", type=str, default="EEG Pz-Oz", 68 | help="File path to the trained model used to estimate walking speeds.") 69 | args = parser.parse_args() 70 | 71 | # Output dir 72 | if not os.path.exists(args.output_dir): 73 | os.makedirs(args.output_dir) 74 | else: 75 | shutil.rmtree(args.output_dir) 76 | os.makedirs(args.output_dir) 77 | 78 | # Select channel 79 | select_ch = args.select_ch 80 | 81 | # Read raw and annotation EDF files 82 | psg_fnames = glob.glob(os.path.join(args.data_dir, "*PSG.edf")) 83 | ann_fnames = glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf")) 84 | psg_fnames.sort() 85 | ann_fnames.sort() 86 | psg_fnames = np.asarray(psg_fnames) 87 | ann_fnames = np.asarray(ann_fnames) 88 | 89 | for i in range(len(psg_fnames)): 90 | # if not "ST7171J0-PSG.edf" in psg_fnames[i]: 91 | # continue 92 | 93 | raw = read_raw_edf(psg_fnames[i], preload=True, stim_channel=None) 94 | sampling_rate = raw.info['sfreq'] 95 | raw_ch_df = raw.to_data_frame(scaling_time=100.0)[select_ch] 96 | raw_ch_df = raw_ch_df.to_frame() 97 | raw_ch_df.set_index(np.arange(len(raw_ch_df))) 98 | 99 | # Get raw header 100 | f = open(psg_fnames[i], 'r', encoding='iso-8859-1') 101 | reader_raw = dhedfreader.BaseEDFReader(f) 102 | reader_raw.read_header() 103 | h_raw = reader_raw.header 104 | f.close() 105 | raw_start_dt = datetime.strptime(h_raw['date_time'], "%Y-%m-%d %H:%M:%S") 106 | 107 | # Read annotation and its header 108 | f = open(ann_fnames[i], 'r', encoding='iso-8859-1') 109 | reader_ann = dhedfreader.BaseEDFReader(f) 110 | reader_ann.read_header() 111 | h_ann = reader_ann.header 112 | _, _, ann = list(zip(*reader_ann.records())) 113 | f.close() 114 | ann_start_dt = datetime.strptime(h_ann['date_time'], "%Y-%m-%d %H:%M:%S") 115 | 116 | # Assert that raw and annotation files start at the same time 117 | assert raw_start_dt == ann_start_dt 118 | 119 | # Generate label and remove indices 120 | remove_idx = [] # indicies of the data that will be removed 121 | labels = [] # indicies of the data that have labels 122 | label_idx = [] 123 | for a in ann[0]: 124 | onset_sec, duration_sec, ann_char = a 125 | ann_str = "".join(ann_char) 126 | label = ann2label[ann_str] 127 | if label != UNKNOWN: 128 | if duration_sec % EPOCH_SEC_SIZE != 0: 129 | raise Exception("Something wrong") 130 | duration_epoch = int(duration_sec / EPOCH_SEC_SIZE) 131 | label_epoch = np.ones(duration_epoch, dtype=np.int) * label 132 | labels.append(label_epoch) 133 | idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=np.int) 134 | label_idx.append(idx) 135 | 136 | print("Include onset:{}, duration:{}, label:{} ({})".format( 137 | onset_sec, duration_sec, label, ann_str 138 | )) 139 | else: 140 | idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=np.int) 141 | remove_idx.append(idx) 142 | 143 | print("Remove onset:{}, duration:{}, label:{} ({})".format( 144 | onset_sec, duration_sec, label, ann_str 145 | )) 146 | labels = np.hstack(labels) 147 | 148 | print("before remove unwanted: {}".format(np.arange(len(raw_ch_df)).shape)) 149 | if len(remove_idx) > 0: 150 | remove_idx = np.hstack(remove_idx) 151 | select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx) 152 | else: 153 | select_idx = np.arange(len(raw_ch_df)) 154 | print("after remove unwanted: {}".format(select_idx.shape)) 155 | 156 | # Select only the data with labels 157 | print("before intersect label: {}".format(select_idx.shape)) 158 | label_idx = np.hstack(label_idx) 159 | select_idx = np.intersect1d(select_idx, label_idx) 160 | print("after intersect label: {}".format(select_idx.shape)) 161 | 162 | # Remove extra index 163 | if len(label_idx) > len(select_idx): 164 | print("before remove extra labels: {}, {}".format(select_idx.shape, labels.shape)) 165 | extra_idx = np.setdiff1d(label_idx, select_idx) 166 | # Trim the tail 167 | if np.all(extra_idx > select_idx[-1]): 168 | # n_trims = len(select_idx) % int(EPOCH_SEC_SIZE * sampling_rate) 169 | # n_label_trims = int(math.ceil(n_trims / (EPOCH_SEC_SIZE * sampling_rate))) 170 | n_label_trims = int(math.ceil(len(extra_idx) / (EPOCH_SEC_SIZE * sampling_rate))) 171 | if n_label_trims!=0: 172 | # select_idx = select_idx[:-n_trims] 173 | labels = labels[:-n_label_trims] 174 | print("after remove extra labels: {}, {}".format(select_idx.shape, labels.shape)) 175 | 176 | # Remove movement and unknown stages if any 177 | raw_ch = raw_ch_df.values[select_idx] 178 | 179 | # Verify that we can split into 30-s epochs 180 | if len(raw_ch) % (EPOCH_SEC_SIZE * sampling_rate) != 0: 181 | raise Exception("Something wrong") 182 | n_epochs = len(raw_ch) / (EPOCH_SEC_SIZE * sampling_rate) 183 | 184 | # Get epochs and their corresponding labels 185 | x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32) 186 | y = labels.astype(np.int32) 187 | 188 | assert len(x) == len(y) 189 | 190 | # Select on sleep periods 191 | w_edge_mins = 30 192 | nw_idx = np.where(y != stage_dict["W"])[0] 193 | start_idx = nw_idx[0] - (w_edge_mins * 2) 194 | end_idx = nw_idx[-1] + (w_edge_mins * 2) 195 | if start_idx < 0: start_idx = 0 196 | if end_idx >= len(y): end_idx = len(y) - 1 197 | select_idx = np.arange(start_idx, end_idx+1) 198 | print(("Data before selection: {}, {}".format(x.shape, y.shape))) 199 | x = x[select_idx] 200 | y = y[select_idx] 201 | print(("Data after selection: {}, {}".format(x.shape, y.shape))) 202 | 203 | # Save 204 | filename = ntpath.basename(psg_fnames[i]).replace("-PSG.edf", ".npz") 205 | save_dict = { 206 | "x": x, 207 | "y": y, 208 | "fs": sampling_rate, 209 | "ch_label": select_ch, 210 | "header_raw": h_raw, 211 | "header_annotation": h_ann, 212 | } 213 | np.savez(os.path.join(args.output_dir, filename), **save_dict) 214 | 215 | print("\n=======================================\n") 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /summary.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | ''' 4 | https://github.com/akaraspt/deepsleepnet 5 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 6 | ''' 7 | 8 | import argparse 9 | import os 10 | import re 11 | 12 | import numpy as np 13 | 14 | from sklearn.metrics import cohen_kappa_score 15 | from sklearn.metrics import confusion_matrix, f1_score 16 | 17 | from ccrrsleep.sleep_stage import W, N1, N2, N3, REM 18 | 19 | 20 | def print_performance(cm): 21 | tp = np.diagonal(cm).astype(np.float) 22 | tpfp = np.sum(cm, axis=0).astype(np.float) # sum of each col 23 | tpfn = np.sum(cm, axis=1).astype(np.float) # sum of each row 24 | acc = np.sum(tp) / np.sum(cm) 25 | precision = tp / tpfp 26 | recall = tp / tpfn 27 | f1 = (2 * precision * recall) / (precision + recall) 28 | mf1 = np.mean(f1) 29 | 30 | 31 | print("Sample: {}".format(np.sum(cm))) 32 | print("W: {}".format(tpfn[W])) 33 | print("N1: {}".format(tpfn[N1])) 34 | print("N2: {}".format(tpfn[N2])) 35 | print("N3: {}".format(tpfn[N3])) 36 | print("REM: {}".format(tpfn[REM])) 37 | print("Confusion matrix:") 38 | print(cm) 39 | print("Precision: {}".format(precision)) 40 | print("Recall: {}".format(recall)) 41 | print("F1: {}".format(f1)) 42 | print("Overall accuracy: {}".format(acc)) 43 | print("Macro-F1 accuracy: {}".format(mf1)) 44 | 45 | 46 | 47 | def perf_overall(data_dir): 48 | # Remove non-output files, and perform ascending sort 49 | allfiles = os.listdir(data_dir) 50 | outputfiles = [] 51 | for idx, f in enumerate(allfiles): 52 | if re.match("^output_.+\d+\.npz", f): 53 | outputfiles.append(os.path.join(data_dir, f)) 54 | outputfiles.sort() 55 | 56 | y_true = [] 57 | y_pred = [] 58 | for fpath in outputfiles: 59 | with np.load(fpath,allow_pickle=True) as f: 60 | print((f["y_true"].shape)) 61 | if len(f["y_true"].shape) == 1: 62 | if len(f["y_true"]) < 10: 63 | f_y_true = np.hstack(f["y_true"]) 64 | f_y_pred = np.hstack(f["y_pred"]) 65 | else: 66 | f_y_true = f["y_true"] 67 | f_y_pred = f["y_pred"] 68 | else: 69 | f_y_true = f["y_true"].flatten() 70 | f_y_pred = f["y_pred"].flatten() 71 | 72 | y_true.extend(f_y_true) 73 | y_pred.extend(f_y_pred) 74 | 75 | print("File: {}".format(fpath)) 76 | cm = confusion_matrix(f_y_true, f_y_pred, labels=[0, 1, 2, 3, 4]) 77 | print_performance(cm) 78 | print(" ") 79 | 80 | y_true = np.asarray(y_true) 81 | y_pred = np.asarray(y_pred) 82 | 83 | cm = confusion_matrix(y_true, y_pred) 84 | acc = np.mean(y_true == y_pred) 85 | mf1 = f1_score(y_true, y_pred, average="macro") 86 | 87 | total = np.sum(cm, axis=1) 88 | 89 | print("CCRRSleepNet (current)") 90 | print_performance(cm) 91 | print("Cohen's kappa score: {}".format(cohen_kappa_score(y_true, y_pred))) 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--data_dir", type=str, default="output", 96 | help="Directory where to load prediction outputs") 97 | args = parser.parse_args() 98 | 99 | if args.data_dir is not None: 100 | perf_overall(data_dir=args.data_dir) 101 | 102 | sharman2017 = np.asarray([ 103 | [7944, 11, 12, 6, 30], 104 | [183, 113, 123, 4, 181], 105 | [48, 4, 3334, 149, 86], 106 | [13, 0, 198, 1088, 0], 107 | [52, 11, 207, 0, 1339] 108 | ], dtype=np.int) 109 | 110 | hassan2017 = np.asarray([ 111 | [3971, 28, 6, 0, 23], 112 | [53, 117, 43, 0, 89], 113 | [70, 5, 1641, 54, 41], 114 | [33, 0, 104, 513, 0], 115 | [41, 24, 84, 1, 655] 116 | ], dtype=np.int) 117 | 118 | tsinalis2016 = np.asarray([ 119 | [2744, 441, 34, 23, 138], 120 | [472, 1654, 262, 8, 366], 121 | [621, 1270, 13696, 1231, 760], 122 | [143, 7, 469, 4966, 6], 123 | [308, 899, 340, 0, 6164] 124 | ], dtype=np.int) 125 | 126 | dong2016 = np.asarray([ 127 | [5022, 577, 188, 19, 395], 128 | [407, 2468, 989, 4, 965], 129 | [130, 630, 27254, 1021, 763], 130 | [13, 0, 1236, 6399, 5], 131 | [103, 258, 609, 0, 9611] 132 | ], dtype=np.int) 133 | 134 | hsu2013 = np.asarray([ 135 | [34, 2, 7, 2, 3], 136 | [0, 20, 23, 3, 9], 137 | [3, 4, 574, 8, 1], 138 | [0, 0, 3, 26, 0], 139 | [3, 5, 13, 4, 213] 140 | ], dtype=np.int) 141 | 142 | liang2012 = np.asarray([ 143 | [195, 24, 4, 0, 3], 144 | [61, 72, 48, 3, 69], 145 | [12, 103, 4078, 216, 220], 146 | [1, 4, 196, 1309, 0], 147 | [8, 8, 22, 6, 1818] 148 | ], dtype=np.int) 149 | 150 | fraiwan2012 = np.asarray([ 151 | [2407, 89, 111, 38, 40], 152 | [56, 185, 52, 8, 48], 153 | [69, 85, 1897, 174, 131], 154 | [14, 9, 86, 482, 3], 155 | [33, 60, 92, 3, 719] 156 | ], dtype=np.int) 157 | 158 | # print(" ") 159 | # print("Sharma (2017)") 160 | # print_performance(sharman2017) 161 | # print(" ") 162 | # print("Hassan (2017)") 163 | # print_performance(hassan2017) 164 | # print(" ") 165 | # print("Tsinalis (2016)") 166 | # print_performance(tsinalis2016) 167 | # print(" ") 168 | # print("Dong (2016)") 169 | # print_performance(dong2016) 170 | # print(" ") 171 | # print("Hsu (2013)") 172 | # print_performance(hsu2013) 173 | # print(" ") 174 | # print("Liang (2012)") 175 | # print_performance(liang2012) 176 | # print(" ") 177 | # print("Fraiwan (2012)") 178 | # print_performance(fraiwan2012) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | ''' 4 | https://github.com/akaraspt/deepsleepnet 5 | Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. 6 | ''' 7 | 8 | 9 | import os 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from ccrrsleep.trainer import CCRRFeatureNetTrainer, CCRRSleepNetTrainer 15 | from ccrrsleep.sleep_stage import (NUM_CLASSES, 16 | EPOCH_SEC_LEN, 17 | SAMPLING_RATE) 18 | 19 | from tensorflow.compat.v1 import ConfigProto 20 | from tensorflow.compat.v1 import InteractiveSession 21 | # config = ConfigProto() 22 | # config.gpu_options.allow_growth = True 23 | # session = InteractiveSession(config=config) 24 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7) 25 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 26 | 27 | FLAGS = tf.app.flags.FLAGS 28 | 29 | tf.app.flags.DEFINE_string('data_dir', '../data/data-39/eeg_fpz_cz', 30 | """Directory where to load training data.""") 31 | tf.app.flags.DEFINE_string('output_dir', 'output', 32 | """Directory where to save trained models """ 33 | """and outputs.""") 34 | tf.app.flags.DEFINE_integer('n_folds', 20, 35 | """Number of cross-validation folds.""") 36 | tf.app.flags.DEFINE_integer('fold_idx', 17, 37 | """Index of cross-validation fold to train.""") 38 | tf.app.flags.DEFINE_integer('pretrain_epochs', 30, 39 | """Number of epochs for pretraining CCRRFeatureNet.""") 40 | tf.app.flags.DEFINE_integer('finetune_epochs', 80, 41 | """Number of epochs for fine-tuning CCRRSleepNet.""") 42 | tf.app.flags.DEFINE_boolean('resume', False, 43 | """Whether to resume the training process.""") 44 | 45 | 46 | def pretrain(n_epochs): 47 | trainer = CCRRFeatureNetTrainer( 48 | data_dir=FLAGS.data_dir, 49 | output_dir=FLAGS.output_dir, 50 | n_folds=FLAGS.n_folds, 51 | fold_idx=FLAGS.fold_idx, 52 | batch_size=256, 53 | input_dims=EPOCH_SEC_LEN*100, 54 | n_classes=NUM_CLASSES, 55 | interval_plot_filter=50, 56 | interval_save_model=50, 57 | interval_print_cm=1 58 | ) 59 | pretrained_model_path = trainer.train( 60 | n_epochs=n_epochs, 61 | resume=FLAGS.resume 62 | ) 63 | return pretrained_model_path 64 | 65 | 66 | def finetune(model_path, n_epochs): 67 | trainer = CCRRSleepNetTrainer( 68 | data_dir=FLAGS.data_dir, 69 | output_dir=FLAGS.output_dir, 70 | n_folds=FLAGS.n_folds, 71 | fold_idx=FLAGS.fold_idx, 72 | batch_size=10, 73 | input_dims=EPOCH_SEC_LEN*100, 74 | n_classes=NUM_CLASSES, 75 | seq_length=25, 76 | n_rnn_layers=2, 77 | return_last=False, 78 | interval_plot_filter=50, 79 | interval_save_model=100, 80 | interval_print_cm=10 81 | ) 82 | finetuned_model_path = trainer.finetune( 83 | pretrained_model_path=model_path, 84 | n_epochs=n_epochs, 85 | resume=FLAGS.resume 86 | ) 87 | return finetuned_model_path 88 | 89 | 90 | def main(argv=None): 91 | # Output dir 92 | output_dir = os.path.join(FLAGS.output_dir, "fold{}".format(FLAGS.fold_idx)) 93 | if not FLAGS.resume: 94 | if tf.gfile.Exists(output_dir): 95 | tf.gfile.DeleteRecursively(output_dir) 96 | tf.gfile.MakeDirs(output_dir) 97 | 98 | pretrained_model_path = pretrain( 99 | n_epochs=FLAGS.pretrain_epochs 100 | ) 101 | 102 | finetuned_model_path = finetune( 103 | model_path=pretrained_model_path, 104 | n_epochs=FLAGS.finetune_epochs 105 | ) 106 | 107 | 108 | def make_print_to_file(path='./',fileName=None): 109 | import os 110 | import sys 111 | import datetime 112 | 113 | class Logger(object): 114 | def __init__(self, filename="Default.log", path="./"): 115 | self.terminal = sys.stdout 116 | # self.log = open(os.path.join(path, filename), "a", encoding='utf8', ) 117 | self.log = open(os.path.join(path, filename), "w", encoding='utf8', ) 118 | 119 | def write(self, message): 120 | self.terminal.write(message) 121 | self.log.write(message) 122 | 123 | def flush(self): 124 | pass 125 | if not fileName: 126 | fileName = datetime.datetime.now().strftime('log_' + '%Y_%m_%d_%Hh_%Mmin') 127 | sys.stdout = Logger(fileName + '.log', path=path) 128 | 129 | ############################################################# 130 | # 这里输出之后的所有的输出的print 内容即将写入日志 131 | ############################################################# 132 | print(fileName.center(60, ' ')) 133 | 134 | 135 | if __name__ == "__main__": 136 | np.random.seed(1) 137 | tf.set_random_seed(1) 138 | log_path = './log' 139 | os.makedirs(log_path, exist_ok=True) 140 | make_print_to_file(log_path,"fold_{}_".format(FLAGS.fold_idx)) 141 | tf.compat.v1.app.run() 142 | # tf.compat.v1.app.run(main()) --------------------------------------------------------------------------------