├── README.md
├── CNNLSTM.py
├── LeaveOneOut_acc.py
├── LICENSE
└── CNNLSTM_VisualizationTech.py

/README.md:
--------------------------------------------------------------------------------
# Subject-Independent-Drowsiness-Recognition-from-Single-Channel-EEG-with-an-Interpretable-CNN-LSTM

PyTorch implementation of the paper "Subject-Independent Drowsiness Recognition from Single-Channel EEG with an Interpretable CNN-LSTM model".
https://doi.org/10.1109/CW52790.2021.00041

If you find the code useful, please cite the paper:

J. Cui et al., "Subject-Independent Drowsiness Recognition from Single-Channel EEG with an Interpretable CNN-LSTM model," 2021 International Conference on Cyberworlds (CW), 2021, pp. 201-208, doi: 10.1109/CW52790.2021.00041.

The project contains three code files, implemented with Python 3.6.6:

- "CNNLSTM.py" contains the model. Required library: torch.
- "LeaveOneOut_acc.py" runs the leave-one-subject-out evaluation and prints the classification accuracies. It requires a CUDA-capable GPU. Required libraries: torch, scipy, numpy, sklearn.
- "CNNLSTM_VisualizationTech.py" implements the visualization technique, which is based on a modification of the LSTM model. It requires a CUDA-capable GPU. Required libraries: torch, scipy, numpy, matplotlib, mne.

The processed dataset has been uploaded to: https://figshare.com/articles/dataset/EEG_driver_drowsiness_dataset/14273687

If you have any problems, please contact Dr. Cui Jian at cuij0006@ntu.edu.sg
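As a quick sanity check, the sketch below (assuming `dataset.mat` has been downloaded from the figshare link above into the repository root) loads the data and pushes a small batch through an untrained model; the variable names follow "LeaveOneOut_acc.py", and the output is only a shape check, not a meaningful prediction:

```python
import numpy as np
import scipy.io as sio
import torch
from CNNLSTM import CNNLSTM

tmp = sio.loadmat('dataset.mat')
xdata = np.array(tmp['EEGsample'])        # (2022, 30, 384): samples x channels x time
x = xdata[:4, [28], :]                    # first 4 samples, Oz channel only
x = torch.from_numpy(x.reshape(4, 1, 1, 384))
model = CNNLSTM().double()
print(model(x).shape)                     # torch.Size([4, 2]): log-probabilities per class
```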
--------------------------------------------------------------------------------
/CNNLSTM.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 2 18:17:29 2019

@author: JIAN
"""
import torch

class CNNLSTM(torch.nn.Module):
    """
    This class implements the CNN-LSTM model proposed in the paper "Subject-Independent
    Drowsiness Recognition from Single-Channel EEG with an Interpretable CNN-LSTM model".
    The network is designed to classify 1D drowsy and alert EEG signals for the purpose
    of driver drowsiness recognition.
    """
    def __init__(self):
        super(CNNLSTM, self).__init__()
        self.feature = 32
        self.padding = torch.nn.ReplicationPad2d((31, 32, 0, 0))
        self.conv = torch.nn.Conv2d(1, self.feature, (1, 64))
        self.batch = Batchlayer(self.feature)
        self.avgpool = torch.nn.AvgPool2d((1, 8))
        # note: self.fc and self.softmax1 are defined but not used in forward()
        self.fc = torch.nn.Linear(32, 2)
        self.softmax = torch.nn.LogSoftmax(dim=1)
        self.softmax1 = torch.nn.Softmax(dim=1)
        self.lstm = torch.nn.LSTM(32, 2)

    def forward(self, source):
        source = self.padding(source)      # (N, 1, 1, 384) -> (N, 1, 1, 447)
        source = self.conv(source)         # -> (N, 32, 1, 384)
        source = self.batch(source)
        source = torch.nn.ELU()(source)
        source = self.avgpool(source)      # -> (N, 32, 1, 48)
        source = source.squeeze()          # -> (N, 32, 48)
        source = source.permute(2, 0, 1)   # -> (48, N, 32): a sequence for the LSTM
        source = self.lstm(source)
        source = source[1][0].squeeze()    # final hidden state -> (N, 2)
        source = self.softmax(source)

        return source

"""
We use a batch normalization layer implemented by ourselves for this model instead of
using the one provided by the PyTorch library. In this implementation, we do not use
momentum, and we initialize the gamma and beta values in the range (-0.1, 0.1).
We obtained slightly higher accuracy with our implementation of the batch
normalization layer.
"""
def normalizelayer(data):
    # normalize over the batch, height and width dimensions (per feature channel)
    eps = 1e-05
    a_mean = data - torch.mean(data, [0, 2, 3], True).expand(
        int(data.size(0)), int(data.size(1)), int(data.size(2)), int(data.size(3)))
    b = torch.div(a_mean, torch.sqrt(torch.mean((a_mean) ** 2, [0, 2, 3], True) + eps).expand(
        int(data.size(0)), int(data.size(1)), int(data.size(2)), int(data.size(3))))

    return b

class Batchlayer(torch.nn.Module):
    def __init__(self, dim):
        super(Batchlayer, self).__init__()
        self.gamma = torch.nn.Parameter(torch.Tensor(1, dim, 1, 1))
        self.beta = torch.nn.Parameter(torch.Tensor(1, dim, 1, 1))
        self.gamma.data.uniform_(-0.1, 0.1)
        self.beta.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # note: no running statistics are kept, so batch statistics are also used at test time
        data = normalizelayer(input)
        gammamatrix = self.gamma.expand(
            int(data.size(0)), int(data.size(1)), int(data.size(2)), int(data.size(3)))
        betamatrix = self.beta.expand(
            int(data.size(0)), int(data.size(1)), int(data.size(2)), int(data.size(3)))

        return data * gammamatrix + betamatrix
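
# A minimal smoke test, runnable on CPU. It assumes the input format used in
# "LeaveOneOut_acc.py": a batch of single-channel, 3-second, 128 Hz EEG samples
# of shape (batch, 1, 1, 384). The dummy data here is random, for illustration only.
if __name__ == '__main__':
    dummy = torch.randn(4, 1, 1, 384).double()
    model = CNNLSTM().double()
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([4, 2]) log-probabilities per class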
--------------------------------------------------------------------------------
/LeaveOneOut_acc.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 2 18:17:29 2019

@author: JIAN
"""
import torch
import scipy.io as sio
import numpy as np
from sklearn.metrics import accuracy_score
import torch.optim as optim
from CNNLSTM import CNNLSTM

torch.cuda.empty_cache()
torch.manual_seed(0)

"""
This file performs leave-one-subject-out cross-subject classification on the driver
drowsiness dataset.

The dataset is available from:
https://figshare.com/articles/dataset/EEG_driver_drowsiness_dataset/14273687

The data file contains 3 variables: EEGsample, substate and subindex.
"EEGsample" contains 2022 EEG samples of size 30x384 from 11 subjects.
Each sample is 3 seconds of EEG data sampled at 128 Hz from 30 EEG channels.

The channel names and their corresponding indices are shown below:
Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCZ, FC4, FT8, T3, C3, Cz, C4, T4, TP7, CP3, CPz, CP4, TP8, T5, P3, PZ, P4, T6, O1, Oz, O2
0,   1,   2,  3,  4,  5,  6,  7,   8,   9,   10,  11,  12, 13, 14, 15, 16, 17,  18,  19,  20,  21,  22, 23, 24, 25, 26, 27, 28, 29

Only the channel Oz is used.

"subindex" is an array of size 2022x1. It contains the subject indices from 1-11
corresponding to each EEG sample.
"substate" is an array of size 2022x1. It contains the labels of the samples:
0 corresponds to the alert state and 1 corresponds to the drowsy state.

This file prints the leave-one-out accuracy for each subject and the overall accuracy.
The overall accuracy for one run is around 72%-73%.

If you have any problems, you can contact Dr. Cui Jian at cuij0006@ntu.edu.sg
"""
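
# A small helper, not called by run() below: a minimal sketch (assuming
# 'dataset.mat' from the figshare link above is in the working directory)
# for checking that the file contains the variables described in the docstring.
def inspect_dataset(filename=r'dataset.mat'):
    tmp = sio.loadmat(filename)
    for key in ('EEGsample', 'substate', 'subindex'):
        print(key, np.array(tmp[key]).shape)
    # expected: EEGsample (2022, 30, 384), substate (2022, 1), subindex (2022, 1)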

def run():

    # load data from the file
    filename = r'dataset.mat'

    tmp = sio.loadmat(filename)
    xdata = np.array(tmp['EEGsample'])
    label = np.array(tmp['substate'])
    subIdx = np.array(tmp['subindex'])

    # astype() returns a new array, so the result must be assigned back
    label = label.astype(int)
    subIdx = subIdx.astype(int)

    samplenum = label.shape[0]

    # there are 11 subjects in the dataset. Each sample is 3 seconds of data
    # from 30 channels with a sampling rate of 128 Hz.
    channelnum = 30
    subjnum = 11
    samplelength = 3
    sf = 128

    # define the learning rate, batch size and number of epochs
    lr = 1e-2
    batch_size = 50
    n_epoch = 15

    # ydata contains the labels of the samples
    ydata = np.zeros(samplenum, dtype=np.longlong)

    for i in range(samplenum):
        ydata[i] = label[i]

    # only channel 28 is used, which corresponds to the Oz channel
    selectedchan = [28]

    # update the xdata and channel number
    xdata = xdata[:, selectedchan, :]
    channelnum = len(selectedchan)

    # results stores the accuracy of every subject
    results = np.zeros(subjnum)

    # perform leave-one-subject-out training and classification:
    # in each iteration, subject i is the testing subject, while all the
    # other subjects are the training subjects.
    for i in range(1, subjnum + 1):

        # form the training data
        trainindx = np.where(subIdx != i)[0]
        xtrain = xdata[trainindx]
        x_train = xtrain.reshape(xtrain.shape[0], 1, channelnum, samplelength * sf)
        y_train = ydata[trainindx]

        # form the testing data
        testindx = np.where(subIdx == i)[0]
        xtest = xdata[testindx]
        x_test = xtest.reshape(xtest.shape[0], 1, channelnum, samplelength * sf)
        y_test = ydata[testindx]

        train = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
        train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)

        # load the CNN-LSTM model to deal with 1D EEG signals
        my_net = CNNLSTM().double().cuda()

        optimizer = optim.Adam(my_net.parameters(), lr=lr)
        loss_class = torch.nn.NLLLoss().cuda()

        for p in my_net.parameters():
            p.requires_grad = True

        # train the classifier
        for epoch in range(n_epoch):
            for j, data in enumerate(train_loader, 0):
                inputs, labels = data

                input_data = inputs.cuda()
                class_label = labels.cuda()

                my_net.zero_grad()
                my_net.train()

                class_output = my_net(input_data)
                err_s_label = loss_class(class_output, class_label)
                err = err_s_label

                err.backward()
                optimizer.step()

        # test the results
        my_net.train(False)
        with torch.no_grad():
            x_test = torch.DoubleTensor(x_test).cuda()
            answer = my_net(x_test)
            # the model outputs log-probabilities; argmax gives the predicted class
            probs = answer.cpu().numpy()
            preds = probs.argmax(axis=-1)
            acc = accuracy_score(y_test, preds)

        print(acc)
        results[i - 1] = acc

    print('mean accuracy:', np.mean(results))

if __name__ == '__main__':
    run()

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Creative Commons Legal Code

CC0 1.0 Universal

CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
HEREUNDER.

Statement of Purpose

The laws of most jurisdictions throughout the world automatically confer
exclusive Copyright and Related Rights (defined below) upon the creator
and subsequent owner(s) (each and all, an "owner") of an original work of
authorship and/or a database (each, a "Work").

Certain owners wish to permanently relinquish those rights to a Work for
the purpose of contributing to a commons of creative, cultural and
scientific works ("Commons") that the public can reliably and without fear
of later claims of infringement build upon, modify, incorporate in other
works, reuse and redistribute as freely as possible in any form whatsoever
and for any purposes, including without limitation commercial purposes.
These owners may contribute to the Commons to promote the ideal of a free
culture and the further production of creative, cultural and scientific
works, or to gain reputation or greater distribution for their Work in
part through the use and efforts of others.

For these and/or other purposes and motivations, and without any
expectation of additional consideration or compensation, the person
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
is an owner of Copyright and Related Rights in the Work, voluntarily
elects to apply CC0 to the Work and publicly distribute the Work under its
terms, with knowledge of his or her Copyright and Related Rights in the
Work and the meaning and intended legal effect of CC0 on those rights.

1. Copyright and Related Rights. A Work made available under CC0 may be
protected by copyright and related or neighboring rights ("Copyright and
Related Rights"). Copyright and Related Rights include, but are not
limited to, the following:

  i. the right to reproduce, adapt, distribute, perform, display,
     communicate, and translate a Work;
 ii. moral rights retained by the original author(s) and/or performer(s);
iii. publicity and privacy rights pertaining to a person's image or
     likeness depicted in a Work;
 iv. rights protecting against unfair competition in regards to a Work,
     subject to the limitations in paragraph 4(a), below;
  v. rights protecting the extraction, dissemination, use and reuse of data
     in a Work;
 vi. database rights (such as those arising under Directive 96/9/EC of the
     European Parliament and of the Council of 11 March 1996 on the legal
     protection of databases, and under any national implementation
     thereof, including any amended or successor version of such
     directive); and
vii. other similar, equivalent or corresponding rights throughout the
     world based on applicable law or treaty, and any national
     implementations thereof.

2. Waiver. To the greatest extent permitted by, but not in contravention
of, applicable law, Affirmer hereby overtly, fully, permanently,
irrevocably and unconditionally waives, abandons, and surrenders all of
Affirmer's Copyright and Related Rights and associated claims and causes
of action, whether now known or unknown (including existing as well as
future claims and causes of action), in the Work (i) in all territories
worldwide, (ii) for the maximum duration provided by applicable law or
treaty (including future time extensions), (iii) in any current or future
medium and for any number of copies, and (iv) for any purpose whatsoever,
including without limitation commercial, advertising or promotional
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
member of the public at large and to the detriment of Affirmer's heirs and
successors, fully intending that such Waiver shall not be subject to
revocation, rescission, cancellation, termination, or any other legal or
equitable action to disrupt the quiet enjoyment of the Work by the public
as contemplated by Affirmer's express Statement of Purpose.

3. Public License Fallback. Should any part of the Waiver for any reason
be judged legally invalid or ineffective under applicable law, then the
Waiver shall be preserved to the maximum extent permitted taking into
account Affirmer's express Statement of Purpose. In addition, to the
extent the Waiver is so judged Affirmer hereby grants to each affected
person a royalty-free, non transferable, non sublicensable, non exclusive,
irrevocable and unconditional license to exercise Affirmer's Copyright and
Related Rights in the Work (i) in all territories worldwide, (ii) for the
maximum duration provided by applicable law or treaty (including future
time extensions), (iii) in any current or future medium and for any number
of copies, and (iv) for any purpose whatsoever, including without
limitation commercial, advertising or promotional purposes (the
"License"). The License shall be deemed effective as of the date CC0 was
applied by Affirmer to the Work. Should any part of the License for any
reason be judged legally invalid or ineffective under applicable law, such
partial invalidity or ineffectiveness shall not invalidate the remainder
of the License, and in such case Affirmer hereby affirms that he or she
will not (i) exercise any of his or her remaining Copyright and Related
Rights in the Work or (ii) assert any associated claims and causes of
action with respect to the Work, in either case contrary to Affirmer's
express Statement of Purpose.

4. Limitations and Disclaimers.

 a. No trademark or patent rights held by Affirmer are waived, abandoned,
    surrendered, licensed or otherwise affected by this document.
 b. Affirmer offers the Work as-is and makes no representations or
    warranties of any kind concerning the Work, express, implied,
    statutory or otherwise, including without limitation warranties of
    title, merchantability, fitness for a particular purpose, non
    infringement, or the absence of latent or other defects, accuracy, or
    the present or absence of errors, whether or not discoverable, all to
    the greatest extent permissible under applicable law.
 c. Affirmer disclaims responsibility for clearing rights of other persons
    that may apply to the Work or any use thereof, including without
    limitation any person's Copyright and Related Rights in the Work.
    Further, Affirmer disclaims responsibility for obtaining any necessary
    consents, permissions or other rights required for any use of the
    Work.
 d. Affirmer understands and acknowledges that Creative Commons is not a
    party to this document and has no duty or obligation with respect to
    this CC0 or use of the Work.
--------------------------------------------------------------------------------
/CNNLSTM_VisualizationTech.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 2 18:17:29 2019

@author: JIAN
"""
"""
This file implements the visualization technique proposed in the paper.
The extracted dataset is available from:
https://figshare.com/articles/dataset/EEG_driver_drowsiness_dataset/14273687
"""
import torch
import scipy.io as sio
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from scipy.integrate import simps  # renamed to scipy.integrate.simpson in newer SciPy releases
from mne.time_frequency import psd_array_multitaper
import matplotlib.gridspec as gridspec
from CNNLSTM import CNNLSTM
import torch.nn as nn

plt.rcParams["mathtext.default"] = 'regular'
plt.rcParams.update({'font.size': 15})

torch.cuda.empty_cache()
torch.manual_seed(0)

class FeatureVis():
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.softmax = nn.Softmax(dim=1)

    def generate_heatmap(self, allsignals, sampleidx, subid, samplelabel, multichannelsignal, likelihood):
        """
        input:
            allsignals: all the signals in the batch
            sampleidx: the index of the sample
            subid: the ID of the subject
            samplelabel: the ground-truth label of the sample
            multichannelsignal: the signals from all channels for the sample
            likelihood: the likelihood of the sample being classified into the alert and drowsy states
        """

        if likelihood[0] > likelihood[1]:
            state = 0
        else:
            state = 1

        if samplelabel == 0:
            labelstr = 'alert'
        else:
            labelstr = 'drowsy'

        fig = plt.figure(figsize=(14, 6))

        # divide the figure layout
        gridlayout = gridspec.GridSpec(ncols=2, nrows=5, figure=fig, wspace=0.2, hspace=0.5)
        axis0 = fig.add_subplot(gridlayout[0:2, 0])
        axis1 = fig.add_subplot(gridlayout[4, 0])
        axis2 = fig.add_subplot(gridlayout[0:5, 1])
        axis3 = fig.add_subplot(gridlayout[2:4, 0])

        # do some preparations
        rawsignal = allsignals[sampleidx].cpu().detach().numpy().squeeze()
        channelnum = multichannelsignal.shape[0]
        samplelength = multichannelsignal.shape[1]
        maxvalue = np.max(np.abs(rawsignal))

        ## calculate the heatmaps for the sample: run the model layer by layer, keeping
        ## the full LSTM output sequence instead of only the final hidden state
        source = self.model.padding(allsignals)
        source = self.model.conv(source)
        source = self.model.batch(source)
        source = torch.nn.ELU()(source)
        source = self.model.avgpool(source)
        source = source.squeeze()
        source = source.permute(2, 0, 1)
        source = self.model.lstm(source)[0]

        hiddenstates = source[:, sampleidx, :].squeeze()
        hiddenstates = self.softmax(hiddenstates)
        hiddenstates = hiddenstates[:, state].cpu().detach().numpy()
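        # The accumulated heatmap below is the softmaxed LSTM output for the
        # predicted class at every time step; the relative heatmap is its
        # discrete derivative, i.e. how much each step changes the evidence for
        # the predicted class. Both are upsampled from the LSTM's 48 steps back
        # to the 384 input samples.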
        flength = hiddenstates.shape[0]
        duplication = int(384 / hiddenstates.shape[0])

        heatmap = np.zeros(flength)
        for i in range(1, flength):
            heatmap[i] = hiddenstates[i] - hiddenstates[i - 1]

        # standardize the heatmap values
        heatmap = (heatmap - np.mean(heatmap)) / np.sqrt(np.sum(heatmap ** 2) / (samplelength))

        relative_heatmap = np.repeat(heatmap, duplication)
        accumulated_heatmap = np.repeat(hiddenstates, duplication)

        fig.suptitle('Subject:' + str(int(subid)) + '   ' + 'Label:' + labelstr + '   ' +
                     '$P_{alert}=$' + str(round(likelihood[0], 2)) +
                     '   $P_{drowsy}=$' + str(round(likelihood[1], 2)), fontsize=25)

        ## calculate the band power components:
        ## delta (1-4 Hz), theta (4-8 Hz), alpha (8-12 Hz) and beta (12-30 Hz)
        psd, freqs = psd_array_multitaper(rawsignal, 128, adaptive=True, normalization='full', verbose=0)
        freq_res = freqs[1] - freqs[0]
        bandpowers = np.zeros(4)

        idx_band = np.logical_and(freqs >= 1, freqs <= 4)
        bandpowers[0] = simps(psd[idx_band], dx=freq_res)
        idx_band = np.logical_and(freqs >= 4, freqs <= 8)
        bandpowers[1] = simps(psd[idx_band], dx=freq_res)
        idx_band = np.logical_and(freqs >= 8, freqs <= 12)
        bandpowers[2] = simps(psd[idx_band], dx=freq_res)
        idx_band = np.logical_and(freqs >= 12, freqs <= 30)
        bandpowers[3] = simps(psd[idx_band], dx=freq_res)

        totalpower = simps(psd, dx=freq_res)
        if totalpower < 1e-8:
            bandpowers = np.zeros(4)
        else:
            bandpowers /= totalpower

        barx = np.arange(1, 5)
        axis1.bar(barx, bandpowers)
        axis1.set_xlim([0, 5])
        axis1.set_ylim([0, 0.8])

        axis1.set_ylabel("Ratio", fontsize=20)

        axis1.set_xticks([1, 2, 3, 4])
        axis1.set_xticklabels(['Delta', 'Theta', 'Alpha', 'Beta'], fontsize=20)

        # draw the accumulated heatmap
        xx = np.arange(1, (samplelength + 1))
        axis0.set_xticks([])
        axis0.set_ylim([-maxvalue - 10, maxvalue + 10])
        axis0.set_xlim([0, (samplelength + 1)])
        axis0.set_ylabel(r"$\mu V$", fontsize=20)

        points = np.array([xx, rawsignal]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        norm = plt.Normalize(0.3, 0.8)

        lc = LineCollection(segments, cmap='viridis', norm=norm)
        lc.set_array(accumulated_heatmap)
        lc.set_linewidth(2)
        axis0.add_collection(lc)
        fig.colorbar(lc, ax=axis0, orientation="horizontal")

        # draw the relative heatmap
        xx = np.arange(1, (samplelength + 1))
        axis3.set_xticks([])
        axis3.set_ylim([-maxvalue - 10, maxvalue + 10])
        axis3.set_xlim([0, (samplelength + 1)])
        axis3.set_ylabel(r"$\mu V$", fontsize=20)

        points = np.array([xx, rawsignal]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        norm = plt.Normalize(-1, 1)
        lc = LineCollection(segments, cmap='viridis', norm=norm)
        lc.set_array(relative_heatmap)
        lc.set_linewidth(2)
        axis3.add_collection(lc)
        fig.colorbar(lc, ax=axis3, orientation="horizontal", ticks=[-1, -0.5, 0, 0.5, 1])
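        # In the all-channel plot below, only the Oz row (channel index 28,
        # second from the top) is colored with the relative heatmap; every
        # other channel is given the constant value -1, the bottom of the
        # color scale.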
        # draw all the signals
        thespan = np.percentile(multichannelsignal, 98)
        yttics = np.zeros(channelnum)
        for i in range(channelnum):
            yttics[i] = i * thespan

        axis2.set_ylim([-thespan, thespan * channelnum])
        axis2.set_xlim([0, samplelength + 1])

        axis2.set_xticks([1, 100, 200, 300, 384])

        labels = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FT7', 'FC3', 'FCZ', 'FC4', 'FT8',
                  'T3', 'C3', 'Cz', 'C4', 'T4', 'TP7', 'CP3', 'CPz', 'CP4', 'TP8', 'T5', 'P3',
                  'PZ', 'P4', 'T6', 'O1', 'Oz', 'O2']

        plt.sca(axis2)
        plt.yticks(yttics, labels, fontsize=13)
        plt.xticks(fontsize=20)

        heatmap1 = np.zeros((channelnum, samplelength)) - 1
        heatmap1[-2, :] = relative_heatmap
        xx = np.arange(1, samplelength + 1)

        for i in range(0, channelnum):
            y = multichannelsignal[i, :] + thespan * i
            dydx = heatmap1[i, :]

            points = np.array([xx, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            norm = plt.Normalize(-1, 1)
            lc = LineCollection(segments, cmap='viridis', norm=norm)
            lc.set_array(dydx)
            lc.set_linewidth(2)
            axis2.add_collection(lc)

        return source


def run():
    filename = r'dataset.mat'
    tmp = sio.loadmat(filename)
    xdata = np.array(tmp['EEGsample'])
    label = np.array(tmp['substate'])
    subIdx = np.array(tmp['subindex'])

    # astype() returns a new array, so the result must be assigned back
    label = label.astype(int)
    subIdx = subIdx.astype(int)
    samplenum = label.shape[0]

    channelnum = 30
    classes = 2
    subjnum = 11
    samplelength = 3

    lr = 1e-2  # learning rate for the small network
    sf = 128
    batch_size = 50
    n_epoch = 15

    ydata = np.zeros(samplenum, dtype=np.longlong)

    for i in range(samplenum):
        ydata[i] = label[i]

    selectedchan = [28]
    rawx = xdata

    xdata = xdata[:, selectedchan, :]
    channelnum = len(selectedchan)

    # you can set the subject id here
    for i in range(2, 3):
        trainindx = np.where(subIdx != i)[0]
        xtrain = xdata[trainindx]
        x_train = xtrain.reshape(xtrain.shape[0], 1, channelnum, samplelength * sf)
        y_train = ydata[trainindx]

        testindx = np.where(subIdx == i)[0]

        xtest = xdata[testindx]
        rawxdata = rawx[testindx]
        x_test = xtest.reshape(xtest.shape[0], 1, channelnum, samplelength * sf)
        y_test = ydata[testindx]

        train = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
        train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)

        my_net = CNNLSTM().double().cuda()

        optimizer = optim.Adam(my_net.parameters(), lr=lr)
        loss_class = torch.nn.NLLLoss().cuda()

        for p in my_net.parameters():
            p.requires_grad = True

        for epoch in range(n_epoch):
            for j, data in enumerate(train_loader, 0):
                inputs, labels = data
                input_data = inputs.cuda()
                class_label = labels.cuda()
                my_net.zero_grad()
                my_net.train()
                class_output = my_net(input_data)
                err_s_label = loss_class(class_output, class_label)
                err = err_s_label
                err.backward()
                optimizer.step()

        my_net.train(False)
        with torch.no_grad():
            x_test = torch.DoubleTensor(x_test).cuda()
            answer = my_net(x_test)
            # the model outputs log-probabilities, so exponentiate to get probabilities
            probs = np.exp(answer.cpu().numpy())

        preds = probs.argmax(axis=-1)
        sampleVis = FeatureVis(my_net)

        # you can set the sample index here
        sampleidx = 0
        sampleVis.generate_heatmap(allsignals=x_test, sampleidx=sampleidx, subid=i,
                                   samplelabel=y_test[sampleidx],
                                   multichannelsignal=rawxdata[sampleidx],
                                   likelihood=probs[sampleidx])

if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------