├── .gitignore
├── requirements.txt
├── model.py
├── README.md
├── dataset.py
├── statistics.py
├── main_training_and_reviewer_testing.py
└── main_training_and_institution_testing.py

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store

.idea/

__pycache__/

DATASET_FNUSA/
DATASET_MAYO/

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.18.1
scipy==1.4.1
pandas==1.0.3
torch==1.4.0
tqdm==4.43.0
scikit-learn==0.22.1
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class NN(nn.Module):
    def __init__(self, NFILT=256, NOUT=4):
        super(NN, self).__init__()
        # conv0 spans all 200 spectrogram frequency bins (kernel height 200) and 3 neighbouring
        # time steps (kernel width 3), producing NFILT features per spectrogram time step
        self.conv0 = nn.Conv2d(1, NFILT, kernel_size=(200, 3), padding=(0, 1), bias=False)
        self.bn0 = nn.BatchNorm2d(NFILT)
        self.gru = nn.GRU(input_size=NFILT, hidden_size=128, num_layers=1, batch_first=True, bidirectional=False)
        self.fc1 = nn.Linear(128, NOUT)

    def forward(self, x):
        # x: (batch, 1, 200, time) z-scored spectrogram
        x = F.relu(self.bn0(self.conv0(x)))  # -> (batch, NFILT, 1, time)
        x = x.squeeze(2).permute(0, 2, 1)    # -> (batch, time, NFILT); squeeze(2) keeps the batch dim even for a batch of size 1
        x, _ = self.gru(x)                   # -> (batch, time, 128)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc1(x)                      # -> (batch, time, NOUT), one prediction per time step
        return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NoiseDetectionCNN-GRU
Example code for the data loading pipeline and CNN-GRU neural network model training.

# Installation:
0. download the dataset and the example code
1. conda create -n NoiseDetection python=3.6 pip=20.0.2
2. conda activate NoiseDetection
3. pip install -r requirements.txt
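
# Usage:
Assuming the downloaded datasets are extracted into ./DATASET_FNUSA/ and ./DATASET_MAYO/ (the directory names expected by dataset.py and excluded from version control in .gitignore), the example scripts can be run directly:
0. python dataset.py (optional integrity check that every segment file loads)
1. python main_training_and_reviewer_testing.py (train on two reviewers, test on the held-out reviewer)
2. python main_training_and_institution_testing.py (train on the FNUSA dataset, test on the Mayo dataset)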

# References:
Nejedly, P., Kremen, V., Sladky, V. et al. Multicenter intracranial EEG dataset for classification of graphoelements and artifactual signals. Sci Data 7, 179 (2020). https://doi.org/10.1038/s41597-020-0532-5

Nejedly, P., Cimbalnik, J., Klimes, P. et al. Intracerebral EEG Artifact Identification Using Convolutional Neural Networks. Neuroinform (2018). https://doi.org/10.1007/s12021-018-9397-6

Nejedly, P., Kremen, V., Sladky, V. et al. Exploiting Graphoelements and Convolutional Neural Networks with Long Short Term Memory for Classification of the Human Electroencephalogram. Sci Rep 9, 11383 (2019). https://doi.org/10.1038/s41598-019-47854-6

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import copy
import scipy.signal as signal
import scipy.stats as stats
import scipy.io as sio
import tqdm


class Dataset:
    def __init__(self, path):
        self.path = path
        if self.path[-1] != '/':
            self.path += '/'
        self.df = pd.read_csv(self.path + 'segments.csv')
        self.NFFF = 200  # number of spectrogram frequency bins kept as model input

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        sid = self.df.iloc[item]['segment_id']
        target = self.df.iloc[item]['category_id']
        data = sio.loadmat(self.path + '{}'.format(sid))['data']
        # spectrogram of the raw signal; each frequency bin is z-scored over time
        _, _, data = signal.spectrogram(data[0, :], fs=5000, nperseg=256, noverlap=128, nfft=1024)

        data = data[:self.NFFF, :]
        data = stats.zscore(data, axis=1)
        data = np.expand_dims(data, axis=0)  # add channel dimension -> (1, NFFF, time)
        return data, target

    def split_reviewer(self, reviewer_id):
        # validation set: segments annotated by reviewer_id; training set: segments from all other reviewers
        train = copy.deepcopy(self)
        valid = copy.deepcopy(self)

        idx = self.df['reviewer_id'] != reviewer_id

        train.df = train.df[idx].reset_index(drop=True)
        valid.df = valid.df[np.logical_not(idx)].reset_index(drop=True)
        return train, valid

    def split_random(self, N_valid):
        # shuffle the segments and hold out the first N_valid of them for validation
        self.df = self.df.sample(frac=1).reset_index(drop=True)
        train = copy.deepcopy(self)
        valid = copy.deepcopy(self)

        train.df = train.df.iloc[N_valid:].reset_index(drop=True)
        valid.df = valid.df.iloc[:N_valid].reset_index(drop=True)
        return train, valid

    def integrity_check(self):
        # iterate through the dataset and check that every segment file can be loaded correctly
        for i in tqdm.tqdm(range(len(self))):
            try:
                self.__getitem__(i)
            except Exception as exc:
                raise RuntimeError('Failed to load segment at index {}'.format(i)) from exc

    def remove_powerline_noise_class(self):
        # drop the powerline noise class (category_id == 0) and shift the remaining labels down by one
        self.df = self.df[self.df['category_id'] != 0]
        self.df['category_id'] = self.df['category_id'] - 1
        self.df = self.df.reset_index(drop=True)
        return self


if __name__ == "__main__":
    Dataset('./DATASET_FNUSA/').integrity_check()
    Dataset('./DATASET_MAYO/').integrity_check()
--------------------------------------------------------------------------------
/statistics.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from scipy.special import softmax, expit
from sklearn.metrics import f1_score, confusion_matrix, cohen_kappa_score, roc_curve, roc_auc_score, average_precision_score


class Statistics(object):
    def __init__(self):
        self.target = []
        self.logits = []

    def reset(self):
        self.target = []
        self.logits = []

    @staticmethod
    def idx2onehot(idx_array):
        y = np.zeros((idx_array.shape[0], idx_array.max() + 1))
        y[np.arange(y.shape[0]), idx_array] = 1
        return y

    @staticmethod
    def F1(conf):
        # per-class F1 score from the confusion matrix: 2*TP / (2*TP + FP + FN)
        x0 = np.sum(conf, 0)
        x1 = np.sum(conf, 1)
        dg = np.diag(conf)
        f1 = 2 * dg / (x0 + x1)
        return f1

    @staticmethod
    def Kappa(conf):
        # Cohen's kappa from the confusion matrix: (observed agreement - expected agreement) / (N - expected agreement)
        x0 = np.sum(conf, 0)
        x1 = np.sum(conf, 1)
        N = np.sum(np.sum(conf))
        ef = np.sum(x0 * x1 / N)
        dg = np.sum(np.diag(conf))
        K = (dg - ef) / (N - ef)
        return K
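
    # Sanity check (added illustration, not part of the original code): for a hypothetical
    # 2-class confusion matrix [[8, 2], [1, 9]], F1() returns [2*8/(10+9), 2*9/(10+11)]
    # = [0.842, 0.857] and Kappa() returns (17 - (9*10 + 11*10)/20) / (20 - 10) = 0.7,
    # in agreement with sklearn's f1_score(average=None) and cohen_kappa_score on the same labels.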

    def append(self, target, logits):
        self.logits.append(logits.data.cpu().numpy())
        self.target.append(target.data.cpu().numpy())

    @staticmethod
    def random_auprc(target):
        # chance-level AUPRC of each class, i.e. the prevalence of that class among the targets
        y_chance = np.zeros((target.max() + 1,))
        for i in range(target.max() + 1):
            y_chance[i] = len(target[target == i]) / len(target)

        return y_chance

    def evaluate(self):
        self.logits = np.concatenate(self.logits)
        self.target = np.concatenate(self.target).astype('int32')

        self.probs = softmax(self.logits, axis=1)
        self.argmax = np.argmax(self.probs, axis=1)

        CONF = np.array(confusion_matrix(y_true=self.target, y_pred=self.argmax))
        F1 = Statistics.F1(CONF)
        KPS = Statistics.Kappa(CONF)
        AUROC = roc_auc_score(y_true=Statistics.idx2onehot(self.target), y_score=self.probs, average=None)
        AUPRC = average_precision_score(y_true=Statistics.idx2onehot(self.target), y_score=self.probs, average=None)
        AUPRC_chance = self.random_auprc(self.target)

        print('Confusion matrix:\n', CONF)
        print('Per-class F1:', F1)
        print('Cohen kappa:', KPS)
        print('Per-class AUROC, mean AUROC:', AUROC, np.mean(AUROC))
        print('Per-class AUPRC, mean AUPRC:', AUPRC, np.mean(AUPRC))
        print('Chance-level AUPRC:', AUPRC_chance)

        self.reset()
--------------------------------------------------------------------------------
/main_training_and_reviewer_testing.py:
--------------------------------------------------------------------------------
###############################################################################
# Example code 1                                                              #
# Train CNN-GRU model on dataset from two reviewers                           #
# Test model on dataset from third reviewer                                   #
# This should be done for each reviewer -> cross-validation of results        #
###############################################################################

import torch
from model import *
from dataset import *
from statistics import *
from torch.utils.data import DataLoader

# Create the training and validation datasets.
# split_reviewer(reviewer_id) splits the dataset by reviewer: segments annotated by the given
# reviewer form the validation set, segments from the other reviewers form the training set.
# Results should be evaluated for all reviewers, i.e. reviewer_id 1, 2 and 3 (cross-validation).
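#
# A minimal sketch of the full cross-validation loop (added illustration; reviewer
# ids 1, 2 and 3 are assumed from the comment above):
#
#   for reviewer_id in (1, 2, 3):
#       train_set, valid_set = Dataset('./DATASET_FNUSA/').split_reviewer(reviewer_id)
#       # ... build the DataLoaders, then train and evaluate exactly as below ...
#
# The single split below holds out reviewer 1 as the validation set.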
dataset_fnusa_train, dataset_fnusa_valid = Dataset('./DATASET_FNUSA/').split_reviewer(1)

NWORKERS = 24
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

TRAIN = DataLoader(dataset=dataset_fnusa_train,
                   batch_size=32,
                   shuffle=True,
                   drop_last=False,
                   num_workers=NWORKERS)

VALID = DataLoader(dataset=dataset_fnusa_valid,
                   batch_size=32,
                   shuffle=True,
                   drop_last=False,
                   num_workers=NWORKERS)


if __name__ == "__main__":
    model = NN().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    loss = nn.CrossEntropyLoss()
    statistics = Statistics()

    for epoch in range(5):
        model.train()
        for i, (x, t) in enumerate(TRAIN):
            optimizer.zero_grad()
            x = x.to(DEVICE).float()
            t = t.to(DEVICE).long()
            y = model(x)
            # the model outputs one prediction per time step; the last time step is used for the loss
            J = loss(input=y[:, -1, :], target=t)
            J.backward()
            optimizer.step()

            if i % 50 == 0:
                print('EPOCH:{}\tITER:{}\tLOSS:{}'.format(str(epoch).zfill(2),
                                                          str(i).zfill(5),
                                                          J.data.cpu().numpy()))

        # evaluate results on the validation set (the held-out reviewer)
        model.eval()
        with torch.no_grad():
            for i, (x, t) in enumerate(VALID):
                x = x.to(DEVICE).float()
                t = t.to(DEVICE).long()
                y = model(x)
                statistics.append(target=t, logits=y[:, -1, :])
        statistics.evaluate()
--------------------------------------------------------------------------------
/main_training_and_institution_testing.py:
--------------------------------------------------------------------------------
##################################################################################
# Example code 2                                                                 #
# Train CNN-GRU model on dataset from one hospital and test on second hospital  #
# This should be done for each hospital -> cross-validation of results          #
##################################################################################

import torch
from model import *
from dataset import *
from statistics import *
from torch.utils.data import DataLoader


# Create the training and testing datasets.
# The powerline noise class is removed because Europe and the USA use different powerline
# frequencies (50 Hz and 60 Hz, respectively), and the network is not able to generalize
# to data from a class that was not represented in the training set.
# If an additional hospital with the same powerline frequency as the training set is used,
# the powerline noise class should not be removed.
dataset_fnusa_train = Dataset('./DATASET_FNUSA/').remove_powerline_noise_class()
dataset_mayo_test = Dataset('./DATASET_MAYO/').remove_powerline_noise_class()

NWORKERS = 24
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

TRAIN = DataLoader(dataset=dataset_fnusa_train,
                   batch_size=32,
                   shuffle=True,
                   drop_last=False,
                   num_workers=NWORKERS)

TEST = DataLoader(dataset=dataset_mayo_test,
                  batch_size=32,
                  shuffle=True,
                  drop_last=False,
                  num_workers=NWORKERS)


if __name__ == "__main__":
    # the model is trained to classify 3 classes, i.e. normal (physiological), noise and pathological
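    # Note (added): remove_powerline_noise_class() above drops category_id == 0 and shifts the
    # remaining labels to 0-2, which is why the model is built with NOUT=3 here instead of the
    # default NOUT=4 used in the reviewer-testing example.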
    model = NN(NOUT=3).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    loss = nn.CrossEntropyLoss()
    statistics = Statistics()

    for epoch in range(5):
        model.train()
        for i, (x, t) in enumerate(TRAIN):
            optimizer.zero_grad()
            x = x.to(DEVICE).float()
            t = t.to(DEVICE).long()
            y = model(x)
            J = loss(input=y[:, -1, :], target=t)
            J.backward()
            optimizer.step()

            if i % 50 == 0:
                print('EPOCH:{}\tITER:{}\tLOSS:{}'.format(str(epoch).zfill(2),
                                                          str(i).zfill(5),
                                                          J.data.cpu().numpy()))

        # evaluate results on the test set (the other institution)
        model.eval()
        with torch.no_grad():
            for i, (x, t) in enumerate(TEST):
                x = x.to(DEVICE).float()
                t = t.to(DEVICE).long()
                y = model(x)
                statistics.append(target=t, logits=y[:, -1, :])
        statistics.evaluate()
--------------------------------------------------------------------------------