# ├── cwt.py
# ├── cwt_band.py
# ├── scaling.py
# ├── signal_show.py
# └── stft.py
#
# /cwt.py:
# --------------------------------------------------------------------------------
__author__ = "Yunzhong Li"
__maintainer__ = "Yunzhong Li"
__version__ = "1.0.1"

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import pywt
import glob
import os
import matplotlib.colors as colors

sampling_rate = 400


def split_signal(data, epoch_length_sec, stride_sec):
    ''' split a recording (e.g. 10 minutes) into epochs, skipping dropouts

    A sample index counts as a dropout when channel 0 equals channel 1 AND
    channel 14 equals channel 15 at that index; no returned epoch contains a
    dropout sample.

    Parameters
    ----------
    data: {2d array-like: samples * channels}
        The input signal, e.g. 240000 x 16 (10 minutes at 400 Hz). Columns
        0, 1, 14 and 15 are used for dropout detection.
    epoch_length_sec: int
        The length (sec) of each epoch; must be positive.
    stride_sec: int
        The length (sec) of stride; must be positive.
        epoch_length_sec == stride_sec means no overlap.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (n_epochs, epoch_length_sec * sampling_rate,
        channels). If no epoch fits, an empty 1-d array is returned, which is
        why callers check ``len(result.shape) == 3``.

    Raises
    ------
    ValueError
        If the epoch length or the stride is not positive (a zero stride
        would otherwise loop forever).
    '''
    signal = np.array(data, dtype=np.float32)
    signal_epochs = []
    samples_in_epoch = epoch_length_sec * sampling_rate
    stride = stride_sec * sampling_rate
    if samples_in_epoch <= 0 or stride <= 0:
        raise ValueError('epoch_length_sec and stride_sec must be positive')

    # compute dropout indices (sorted, unique); len(signal) is appended as a
    # sentinel so the last dropout-free run is processed too
    drop_indices_c0 = np.where(signal[:, 0] == signal[:, 1])[0]
    drop_indices_c1 = np.where(signal[:, 14] == signal[:, 15])[0]
    drop_indices = np.intersect1d(drop_indices_c0, drop_indices_c1)
    drop_indices = np.append(drop_indices, len(signal))

    window_start = 0
    for i in drop_indices:
        epoch_start = window_start
        epoch_end = epoch_start + samples_in_epoch

        # signal[epoch_start:epoch_end] stops at index epoch_end - 1, so an
        # epoch ending exactly at the dropout / sentinel index i is still clean
        while epoch_end <= i:
            signal_epochs.append(signal[epoch_start:epoch_end, :])
            epoch_start += stride
            epoch_end += stride

        window_start = i + 1

    return np.array(signal_epochs)


def cwt(signal, wavename='cgau8', totalscal=201, sampling_rate=400):
    '''do continuous wavelet transform and return its magnitude

    Parameters
    ----------
    signal: {array-like}
        The input signal; the transform runs along the last axis (pywt.cwt's
        default), so a (n_epochs, samples) array is transformed epoch by epoch.
    wavename: {string}
        The wavelet used to transform the signal.
    totalscal: {int}
        totalscal - 1 scales are used (200 with the default 201). Scale k maps
        to k * sampling_rate / (2 * totalscal) Hz for k = totalscal down to 2,
        i.e. 200 Hz down to about 2 Hz with the defaults.
    sampling_rate: {int}
        The sampling rate of the signal, default 400 Hz.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        |coefficients| of shape (totalscal - 1, *signal.shape), and the
        frequency (Hz) of each scale in the same order.
    '''

    fc = pywt.central_frequency(wavename)  # central frequency
    cparam = 2 * fc * totalscal
    scales = cparam / np.arange(totalscal, 1, -1)  # calculate scales
    cwt_signal, frequencies = pywt.cwt(signal, scales, wavename, 1.0 / sampling_rate)
    return np.abs(cwt_signal), frequencies


if __name__ == '__main__':
    data_dir = 'data/Pat1_1/'
    out_dir = './image/continuous/Pat1_10sec_1'
    os.makedirs(out_dir, exist_ok=True)

    # glob.glob1 is undocumented and deprecated since Python 3.13
    files = sorted(os.path.basename(path) for path in glob.glob(os.path.join(data_dir, '*.pkl')))

    for file_name in files:
        # the name is parsed as '<10-char prefix><segment>_<label>.pkl'
        segment_no, label = file_name[10:-4].split('_')
        df = pd.read_pickle(os.path.join(data_dir, file_name))
        signal = df.loc[:, 'ch0':'ch15']

        # signal split: 10 s epochs, no overlap
        signal = split_signal(signal, 10, 10)

        # an empty (1-d) result means no dropout-free epoch was found
        if len(signal.shape) == 3:
            # signal with channel0: (n_epochs, samples_per_epoch)
            signal_c0 = signal[:, :, 0]

            # cwt magnitude: (200 scales, n_epochs, samples_per_epoch)
            cwt_signal, frequencies = cwt(signal_c0)

            for epoch in range(cwt_signal.shape[1]):
                epoch_cwt = cwt_signal[:, epoch, :]
                # each scale's share of the total at every time step, log10 scaled
                ret = np.log10(epoch_cwt / np.sum(epoch_cwt, axis=0))

                # log plot without ticks/margins so the image is only the scalogram
                fig = plt.figure(figsize=(10, 4))
                plt.gca().xaxis.set_major_locator(plt.NullLocator())
                plt.gca().yaxis.set_major_locator(plt.NullLocator())
                plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
                plt.margins(0, 0)
                t = np.arange(0, ret.shape[1])
                plt.pcolormesh(t, frequencies, np.float32(ret), norm=colors.Normalize(vmin=-6, vmax=-1))
                plt.axis('off')
                plt.colorbar()
                name = str(segment_no) + '_' + str(epoch) + '_' + str(label)
                plt.savefig(os.path.join(out_dir, name + '.jpeg'))
                # release the figure; pyplot otherwise keeps every one alive
                plt.close(fig)

# --------------------------------------------------------------------------------
# /cwt_band.py:
# --------------------------------------------------------------------------------
__author__ = "Yunzhong Li"
__maintainer__ = "Yunzhong Li"
__version__ = "1.0.1"

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.colors as colors
import pywt
import glob
import os

# frequency bands in Hz, [low, high): delta, theta, alpha, beta, low-gamma, high-gamma
PSD_FREQ = np.array([[0, 4], [4, 8], [8, 12], [12, 30], [30, 70], [70, 200]])
sampling_rate = 400


def split_signal(data, epoch_length_sec, stride_sec):
    ''' split a recording (e.g. 10 minutes) into epochs, skipping dropouts

    A sample index counts as a dropout when channel 0 equals channel 1 AND
    channel 14 equals channel 15 at that index; no returned epoch contains a
    dropout sample.

    Parameters
    ----------
    data: {2d array-like: samples * channels}
        The input signal, e.g. 240000 x 16 (10 minutes at 400 Hz). Columns
        0, 1, 14 and 15 are used for dropout detection.
    epoch_length_sec: int
        The length (sec) of each epoch; must be positive.
    stride_sec: int
        The length (sec) of stride; must be positive.
        epoch_length_sec == stride_sec means no overlap.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (n_epochs, epoch_length_sec * sampling_rate,
        channels). If no epoch fits, an empty 1-d array is returned, which is
        why callers check ``len(result.shape) == 3``.

    Raises
    ------
    ValueError
        If the epoch length or the stride is not positive (a zero stride
        would otherwise loop forever).
    '''
    signal = np.array(data, dtype=np.float32)
    signal_epochs = []
    samples_in_epoch = epoch_length_sec * sampling_rate
    stride = stride_sec * sampling_rate
    if samples_in_epoch <= 0 or stride <= 0:
        raise ValueError('epoch_length_sec and stride_sec must be positive')

    # compute dropout indices (sorted, unique); len(signal) is appended as a
    # sentinel so the last dropout-free run is processed too
    drop_indices_c0 = np.where(signal[:, 0] == signal[:, 1])[0]
    drop_indices_c1 = np.where(signal[:, 14] == signal[:, 15])[0]
    drop_indices = np.intersect1d(drop_indices_c0, drop_indices_c1)
    drop_indices = np.append(drop_indices, len(signal))

    window_start = 0
    for i in drop_indices:
        epoch_start = window_start
        epoch_end = epoch_start + samples_in_epoch

        # signal[epoch_start:epoch_end] stops at index epoch_end - 1, so an
        # epoch ending exactly at the dropout / sentinel index i is still clean
        while epoch_end <= i:
            signal_epochs.append(signal[epoch_start:epoch_end, :])
            epoch_start += stride
            epoch_end += stride

        window_start = i + 1

    return np.array(signal_epochs)


def cwt(signal, wavename='cgau8', totalscal=201, sampling_rate=400):
    '''do continuous wavelet transform and return its magnitude

    Parameters
    ----------
    signal: {array-like}
        The input signal; the transform runs along the last axis (pywt.cwt's
        default), so a (n_epochs, samples) array is transformed epoch by epoch.
    wavename: {string}
        The wavelet used to transform the signal.
    totalscal: {int}
        totalscal - 1 scales are used (200 with the default 201). Scale k maps
        to k * sampling_rate / (2 * totalscal) Hz for k = totalscal down to 2,
        i.e. 200 Hz down to about 2 Hz with the defaults.
    sampling_rate: {int}
        The sampling rate of the signal, default 400 Hz.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        |coefficients| of shape (totalscal - 1, *signal.shape), and the
        frequency (Hz) of each scale in the same order.
    '''

    fc = pywt.central_frequency(wavename)  # central frequency
    cparam = 2 * fc * totalscal
    scales = cparam / np.arange(totalscal, 1, -1)  # calculate scales
    cwt_signal, frequencies = pywt.cwt(signal, scales, wavename, 1.0 / sampling_rate)
    return np.abs(cwt_signal), frequencies


if __name__ == '__main__':
    data_dir = 'data/Pat1_0/'
    out_dir = './image/band/Pat1_30sec_0'
    os.makedirs(out_dir, exist_ok=True)

    # one more label than there are band rows, so pcolormesh uses them as row boundaries
    BAND = np.array(['0', 'Delta', 'Theta', 'Alpha', 'Beta', 'low-gamma', 'high-gamma'])

    # glob.glob1 is undocumented and deprecated since Python 3.13
    files = sorted(os.path.basename(path) for path in glob.glob(os.path.join(data_dir, '*.pkl')))

    for file_name in files:
        # the name is parsed as '<10-char prefix><segment>_<label>.pkl'
        segment_no, label = file_name[10:-4].split('_')
        df = pd.read_pickle(os.path.join(data_dir, file_name))
        signal = df.loc[:, 'ch0':'ch15']

        # signal split: 30 s epochs, no overlap
        signal = split_signal(signal, 30, 30)

        # an empty (1-d) result means no dropout-free epoch was found
        if len(signal.shape) == 3:
            # signal with channel0: (n_epochs, samples_per_epoch)
            signal_c0 = signal[:, :, 0]

            # cwt magnitude: (200 scales, n_epochs, samples_per_epoch)
            cwt_signal, frequencies = cwt(signal_c0)

            # scales falling inside each band; identical for every epoch
            band_masks = [(frequencies >= low) & (frequencies < high) for low, high in PSD_FREQ]

            for epoch in range(cwt_signal.shape[1]):
                # mean |CWT| of each band at every time step: (6, samples_per_epoch)
                ret = np.array([cwt_signal[mask, epoch, :].mean(0) for mask in band_masks])
                # relative band power, log10 scaled
                ret = np.log10(ret / np.sum(ret, axis=0))

                # log plot without ticks/margins so the image is only the band map
                fig = plt.figure(figsize=(2.56, 2.56))
                plt.gca().xaxis.set_major_locator(plt.NullLocator())
                plt.gca().yaxis.set_major_locator(plt.NullLocator())
                plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
                plt.margins(0, 0)
                t = np.arange(0, ret.shape[1])
                plt.pcolormesh(t, BAND, np.float32(ret), norm=colors.Normalize(vmin=-2, vmax=0))
                plt.axis('off')
                name = str(segment_no) + '_' + str(epoch) + '_' + str(1) + '_' + str(label)
                plt.savefig(os.path.join(out_dir, name + '.jpeg'))
                # release the figure; pyplot otherwise keeps every one alive
                plt.close(fig)
# --------------------------------------------------------------------------------
# /scaling.py:
# --------------------------------------------------------------------------------


__author__ = "Yunzhong Li"
__maintainer__ = "Yunzhong Li"
__version__ = "1.0.1"

import numpy as np
import pandas as pd
import pywt
import glob
import os


def cwt(signal, wavename='cgau8', totalscal=201, sampling_rate=400):
    '''do continuous wavelet transform and return its magnitude

    Parameters
    ----------
    signal: {array-like}
        The input signal; the transform runs along the last axis (pywt.cwt's
        default). The main block below passes one channel as a 1-d array.
    wavename: {string}
        The wavelet used to transform the signal.
    totalscal: {int}
        totalscal - 1 scales are used (200 with the default 201). Scale k maps
        to k * sampling_rate / (2 * totalscal) Hz for k = totalscal down to 2,
        i.e. 200 Hz down to about 2 Hz with the defaults.
    sampling_rate: {int}
        The sampling rate of the signal, default 400 Hz.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        |coefficients| of shape (totalscal - 1, *signal.shape), and the
        frequency (Hz) of each scale in the same order.
    '''

    fc = pywt.central_frequency(wavename)  # central frequency
    cparam = 2 * fc * totalscal
    scales = cparam / np.arange(totalscal, 1, -1)  # calculate scales
    cwt_signal, frequencies = pywt.cwt(signal, scales, wavename, 1.0 / sampling_rate)
    return np.abs(cwt_signal), frequencies


if __name__ == '__main__':
    data_dir = 'data/Pat1Train'

    # glob.glob1 is undocumented and deprecated since Python 3.13
    files = sorted(os.path.basename(path) for path in glob.glob(os.path.join(data_dir, '*.pkl')))

    # running extrema of |CWT| per (channel, scale) over all files. Start at
    # +/-inf so the first observed value always replaces them; a start value
    # of 1 would hide every minimum that is larger than 1.
    min_band = np.full((16, 200), np.inf)
    max_band = np.full((16, 200), -np.inf)

    for file_idx, file_name in enumerate(files):
        df = pd.read_pickle(os.path.join(data_dir, file_name))
        signal = np.array(df.loc[:, 'ch0':'ch15'])

        for channel in range(signal.shape[1]):
            # cwt_signal: (200 scales, samples); reduce over time for all scales at once
            cwt_signal, _ = cwt(signal[:, channel])
            min_band[channel] = np.minimum(min_band[channel], cwt_signal.min(axis=1))
            max_band[channel] = np.maximum(max_band[channel], cwt_signal.max(axis=1))
        print(file_idx)

    print(min_band)
    print(max_band)
    np.savetxt('./min.txt', min_band)
    np.savetxt('./max.txt', max_band)

# --------------------------------------------------------------------------------
# /signal_show.py:
# --------------------------------------------------------------------------------
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np

if __name__ == '__main__':
    # plot the raw 'ch1' trace of one segment against its sample index
    df = pd.read_pickle('data/Pat1_1/Pat1Train_121_1.pkl')

    signal = np.array(df['ch1'])

    t = np.arange(0, len(signal))
    plt.plot(t, signal)
    plt.show()

# --------------------------------------------------------------------------------
# /stft.py:
# --------------------------------------------------------------------------------
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.signal import stft
import glob
import os


def normalization(signal):
    '''min-max normalize a 1-d array to the 0-1 range

    The main block calls this once per frequency row of the STFT magnitude,
    so each frequency band is scaled separately.

    Parameters
    ----------
    signal: {1d array-like}
        One frequency row of the STFT magnitude.

    Returns
    -------
    numpy.ndarray
        (signal - min) / (max - min) as floats. A constant row (max == min)
        maps to all zeros instead of NaN.
    '''
    values = np.asarray(signal, dtype=float)
    min_val = values.min()
    value_range = values.max() - min_val
    if value_range == 0:
        # avoid 0 / 0 -> NaN for a flat row
        return np.zeros(values.shape)
    return (values - min_val) / value_range


def stft_transform(signal, nperseg=400, noverlap=1, fs=400):
    '''do short-time fourier transform and return its magnitude

    Parameters
    ----------
    signal: {1d array-like}
        The input signal; the main block passes one channel.
    nperseg: {int}
        The length of each window; at a 400 Hz sampling rate, 400 points is 1 sec.
    noverlap: {int}
        Number of points shared by consecutive windows, so the hop is
        nperseg - noverlap points. The default of 1 gives a 399-point hop,
        i.e. one point of overlap rather than none.
    fs: {int}
        The sampling rate of the signal, default 400 Hz.

    Returns
    -------
    f, t, |Zxx|
        Sample frequencies (Hz), segment times (s) and the STFT magnitude;
        the last axis of |Zxx| corresponds to t.
    '''
    f, t, zxx = stft(signal, nperseg=nperseg, noverlap=noverlap, fs=fs)
    return f, t, np.abs(zxx)


if __name__ == '__main__':
    data_dir = 'data'
    out_dir = './image_stft'
    os.makedirs(out_dir, exist_ok=True)

    # glob.glob1 is undocumented and deprecated since Python 3.13
    files = sorted(os.path.basename(path) for path in glob.glob(os.path.join(data_dir, '*.pkl')))

    for file_name in files:
        # the name is parsed as '<10-char prefix><segment>_<label>.pkl'
        segment_no, label = file_name[10:-4].split('_')
        df = pd.read_pickle(os.path.join(data_dir, file_name))
        signal = df['ch13']

        # stft
        f, t, zxx = stft_transform(signal)

        # normalization: every frequency row scaled to 0-1 separately
        norm_signal = np.array([normalization(row) for row in zxx])

        # split into 20 windows of 30 STFT frames (about 30 s each)
        for window in range(20):
            frames = slice(window * 30, (window + 1) * 30)
            fig = plt.figure(figsize=(8, 2))
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.pcolormesh(t[frames], f, norm_signal[:, frames])
            plt.axis('off')
            name = str(segment_no) + '_' + str(window) + '_' + str(label)
            plt.savefig(os.path.join(out_dir, name + '.png'))
            # release the figure; pyplot otherwise keeps every one alive
            plt.close(fig)

# --------------------------------------------------------------------------------