├── deepviz ├── __init__.pyc ├── saliency.pyc ├── images │ ├── cat_dog.png │ ├── doberman.png │ ├── cat_dog_viz.png │ └── doberman_viz.png ├── guided_backprop.pyc ├── visual_backprop.pyc ├── utils.py ├── README.md ├── integrated_gradients.py ├── visual_backprop.py ├── saliency.py └── guided_backprop.py ├── a7503a85-c8f9-44b0-8c92-c440ab2b91de ├── README.md ├── segmenting.py ├── subjective_training.py ├── visualizations.py ├── models.py ├── intersubjective_training.py └── src.py /deepviz/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/__init__.pyc -------------------------------------------------------------------------------- /deepviz/saliency.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/saliency.pyc -------------------------------------------------------------------------------- /deepviz/images/cat_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/cat_dog.png -------------------------------------------------------------------------------- /deepviz/guided_backprop.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/guided_backprop.pyc -------------------------------------------------------------------------------- /deepviz/images/doberman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/doberman.png -------------------------------------------------------------------------------- /deepviz/visual_backprop.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/visual_backprop.pyc -------------------------------------------------------------------------------- /deepviz/images/cat_dog_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/cat_dog_viz.png -------------------------------------------------------------------------------- /deepviz/images/doberman_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/doberman_viz.png -------------------------------------------------------------------------------- /a7503a85-c8f9-44b0-8c92-c440ab2b91de: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/a7503a85-c8f9-44b0-8c92-c440ab2b91de -------------------------------------------------------------------------------- /deepviz/utils.py: -------------------------------------------------------------------------------- 1 | def show_image(image, grayscale = True, ax=None, title=''): 2 | if ax is None: 3 | plt.figure() 4 | plt.axis('off') 5 | 6 | if len(image.shape) == 2 or grayscale == True: 7 | if len(image.shape) == 3: 8 | image = np.sum(np.abs(image), axis=2) 9 | 10 | vmax = np.percentile(image, 99) 11 | vmin = np.min(image) 12 | 13 | plt.imshow(image, cmap=plt.cm.gray, vmin=vmin, vmax=vmax) 14 | plt.title(title) 15 | else: 16 | image = image + 127.5 17 | image = image.astype('uint8') 18 | 19 | plt.imshow(image) 20 | plt.title(title) 21 | 22 | def load_image(file_path): 23 | im = PIL.Image.open(file_path) 24 | im = np.asarray(im) 25 | 26 | return im - 127.5 -------------------------------------------------------------------------------- /deepviz/README.md: -------------------------------------------------------------------------------- 1 | 2 | This repository contains the implementations in Keras of various methods to understand the prediction by a Convolutional Neural Networks. Implemented methods are: 3 | 4 | * Vanila gradient [https://arxiv.org/abs/1312.6034] 5 | * Guided backprop [https://arxiv.org/abs/1412.6806] 6 | * Integrated gradient [https://arxiv.org/abs/1703.01365] 7 | * Visual backprop [https://arxiv.org/abs/1611.05418] 8 | 9 | Each of them is accompanied with the corresponding smoothgrad version [https://arxiv.org/abs/1706.03825], which improves on any baseline method by adding random noise. 10 | 11 | Courtesy of https://github.com/tensorflow/saliency and https://github.com/mbojarski/VisualBackProp. 12 | 13 | # Examples 14 | 15 | * Dog 16 | 17 | 18 | 19 | * Dog and Cat 20 | 21 | 22 | 23 | 24 | # Usage 25 | 26 | cd deep-viz-keras 27 | 28 | ```python 29 | from guided_backprop import GuidedBackprop 30 | from utils import * 31 | from keras.applications.vgg16 import VGG16 32 | 33 | # Load the pretrained VGG16 model and make the guided backprop operator 34 | vgg16_model = VGG16(weights='imagenet') 35 | vgg16_model.compile(loss='categorical_crossentropy', optimizer='adam') 36 | guided_bprop = GuidedBackprop(vgg16_model) 37 | 38 | # Load the image and compute the guided gradient 39 | image = load_image('/path/to/image') 40 | mask = guided_bprop.get_mask(image) # compute the gradients 41 | show_image(mask) # display the grayscaled mask 42 | ``` 43 | 44 | The examples.ipynb contains the demos of all implemented methods using the built-in VGG16 model of Keras. 45 | -------------------------------------------------------------------------------- /deepviz/integrated_gradients.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities to compute an IntegratedGradients SaliencyMask.""" 16 | 17 | import numpy as np 18 | from saliency import GradientSaliency 19 | 20 | class IntegratedGradients(GradientSaliency): 21 | """A SaliencyMask class that implements the integrated gradients method. 22 | 23 | https://arxiv.org/abs/1703.01365 24 | """ 25 | 26 | def GetMask(self, input_image, input_baseline=None, nsamples=100): 27 | """Returns a integrated gradients mask.""" 28 | if input_baseline == None: 29 | input_baseline = np.zeros_like(input_image) 30 | 31 | assert input_baseline.shape == input_image.shape 32 | 33 | input_diff = input_image - input_baseline 34 | 35 | total_gradients = np.zeros_like(input_image) 36 | 37 | for alpha in np.linspace(0, 1, nsamples): 38 | input_step = input_baseline + alpha * input_diff 39 | total_gradients += super(IntegratedGradients, self).get_mast(input_step) 40 | 41 | return total_gradients * input_diff 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # P300 Brain Computer Interface (BCI) 2 | Brain Computer Interface project in which I use wide and deep convolutional neural networks to decode P300 component in EEG signals for a speller application. I also use saliency maps for visualizing the features learned by the model. These maps are then quantified to reveal the task-related brain dynamics. 3 |

4 |
5 |

6 | 7 | #### Related Publications 8 | [Abstract](http://doi.org/10.12751/nncn.bc2018.0092) and [Poster](http://bit.ly/2PdqhcG) published in Bernstein conference for computational neuroscience in Berlin in Septemper 2018. 9 | 10 | [Paper](https://doi.org/10.1088/1741-2552/ab3bb4) Convolutional Neural Networks for Decoding of Covert Attention Focus and Saliency Maps for EEG Feature Visualization. 11 | 12 | For Citations. 13 | 14 | `@Article{Farahat_2019, 15 | Title = {Convolutional neural networks for decoding of covert attention focus and saliency maps for {EEG} feature visualization}, 16 | Author = {Amr Farahat and Christoph Reichert and Catherine M Sweeney-Reed and Hermann Hinrichs}, 17 | Journal = {Journal of Neural Engineering}, 18 | Year = {2019}, 19 | Month = {oct}, 20 | Number = {6}, 21 | Pages = {066010}, 22 | Volume = {16}, 23 | Doi = {10.1088/1741-2552/ab3bb4}, 24 | Publisher = {{IOP} Publishing}, 25 | Url = {https://doi.org/10.1088%2F1741-2552%2Fab3bb4} 26 | } 27 | ` 28 | 29 | 30 | [Preprint](https://www.biorxiv.org/content/10.1101/614784v1) 31 | 32 | 33 | -------------------------------------------------------------------------------- /segmenting.py: -------------------------------------------------------------------------------- 1 | ## Reading the EEG data from matlab .mat files and segmenting them to create the training examples. 2 | 3 | from scipy.io import loadmat 4 | import numpy as np 5 | import os 6 | 7 | # create list of files. one file per subject data. 8 | path = './data/p300/resampled_50/' 9 | files = [f for f in os.listdir(path) if f.endswith('.mat')] 10 | 11 | for fname in files: 12 | 13 | print 'working on file', fname 14 | data = loadmat(path+fname) 15 | sampling_rate = 5.086299945003593e+02/10 16 | # creating variables and organizing axeses 17 | # EEG data 18 | eeg = data['eeg'] 19 | eeg = np.swapaxes(eeg, 0,2) 20 | eeg = np.swapaxes(eeg, 1, 2) 21 | # the attended targets 22 | target = data['target'] 23 | target = np.reshape(target, (target.shape[0],)) 24 | # the stimuli 25 | trigger = data['trigger'] 26 | trigger = np.swapaxes(trigger, 0,1) 27 | # number of trials 28 | trials = eeg.shape[0] 29 | # choosing the pre and post stimulus interval for each segment 30 | prev = int(round(sampling_rate*0.1)) 31 | aft = int(round(sampling_rate*0.7)) 32 | segments = np.empty(shape=(trials*60,30,prev+aft)) 33 | labels = [] 34 | triggers = [] 35 | # looping through all the trials of the subjects and segment them. There are 60 segments per trial. 36 | for k in range(trials): 37 | indices = np.nonzero(trigger[k]) 38 | trigs = trigger[k,indices] 39 | indices = np.round(indices[0]/10.0).astype(int) 40 | for j,i in enumerate(indices[0:60]): 41 | segment = eeg[k,:,i-prev:i+aft] 42 | segments[k*60+j] = segment 43 | if trigs[0,j] == target[k]: 44 | labels.append(1) 45 | else: 46 | labels.append(0) 47 | triggers.append(trigs[0,j]) 48 | labels = np.array(labels) 49 | triggers = np.array(triggers) 50 | 51 | print segments.shape, labels.shape, triggers.shape, target.shape 52 | 53 | # create new directory for the subject and save tje arrays. 54 | subject = os.path.splitext(fname)[0].split('_')[1] 55 | os.makedirs('./data/50/subjects/'+subject) 56 | # save the arrays to files to be used agian for modelling. 57 | np.save(open('./data/50/subjects/'+subject+'/trials.npy', 'w'), eeg) 58 | np.save(open('./data/50/subjects/'+subject+'/segments.npy', 'w'), segments) 59 | np.save(open('./data/50/subjects/'+subject+'/labels.npy', 'w'), labels) 60 | np.save(open('./data/50/subjects/'+subject+'/triggers.npy', 'w'), triggers) 61 | np.save(open('./data/50/subjects/'+subject+'/target.npy', 'w'), target) 62 | -------------------------------------------------------------------------------- /deepviz/visual_backprop.py: -------------------------------------------------------------------------------- 1 | from saliency import SaliencyMask 2 | import numpy as np 3 | import keras.backend as K 4 | from keras.layers import Input, Conv2DTranspose 5 | from keras.models import Model 6 | from keras.initializers import Ones, Zeros 7 | 8 | class VisualBackprop(SaliencyMask): 9 | """A SaliencyMask class that computes saliency masks with VisualBackprop (https://arxiv.org/abs/1611.05418). 10 | """ 11 | 12 | def __init__(self, model, output_index=0): 13 | """Constructs a VisualProp SaliencyMask.""" 14 | inps = [model.input, K.learning_phase()] # input placeholder 15 | outs = [layer.output for layer in model.layers] # all layer outputs 16 | self.forward_pass = K.function(inps, outs) # evaluation function 17 | 18 | self.model = model 19 | 20 | def get_mask(self, input_image): 21 | """Returns a VisualBackprop mask.""" 22 | x_value = np.expand_dims(input_image, axis=0) 23 | 24 | visual_bpr = None 25 | layer_outs = self.forward_pass([x_value, 0]) 26 | 27 | for i in range(len(self.model.layers)-1, -1, -1): 28 | if 'Conv2D' in str(type(self.model.layers[i])): 29 | layer = np.mean(layer_outs[i], axis=3, keepdims=True) 30 | layer = layer - np.min(layer) 31 | layer = layer/(np.max(layer)-np.min(layer)+1e-6) 32 | 33 | if visual_bpr is not None: 34 | if visual_bpr.shape != layer.shape: 35 | visual_bpr = self._deconv(visual_bpr) 36 | visual_bpr = visual_bpr * layer 37 | else: 38 | visual_bpr = layer 39 | 40 | return visual_bpr[0] 41 | 42 | def _deconv(self, feature_map): 43 | """The deconvolution operation to upsample the average feature map downstream""" 44 | x = Input(shape=(None, None, 1)) 45 | y = Conv2DTranspose(filters=1, 46 | kernel_size=(3,3), 47 | strides=(2,2), 48 | padding='same', 49 | kernel_initializer=Ones(), 50 | bias_initializer=Zeros())(x) 51 | 52 | deconv_model = Model(inputs=[x], outputs=[y]) 53 | 54 | inps = [deconv_model.input, K.learning_phase()] # input placeholder 55 | outs = [deconv_model.layers[-1].output] # output placeholder 56 | deconv_func = K.function(inps, outs) # evaluation function 57 | 58 | return deconv_func([feature_map, 0])[0] -------------------------------------------------------------------------------- /deepviz/saliency.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities to compute SaliencyMasks.""" 16 | import numpy as np 17 | import keras.backend as K 18 | 19 | class SaliencyMask(object): 20 | """Base class for saliency masks. Alone, this class doesn't do anything.""" 21 | def __init__(self, model, output_index=0): 22 | """Constructs a SaliencyMask. 23 | 24 | Args: 25 | model: the keras model used to make prediction 26 | output_index: the index of the node in the last layer to take derivative on 27 | """ 28 | pass 29 | 30 | def get_mask(self, input_image): 31 | """Returns an unsmoothed mask. 32 | 33 | Args: 34 | input_image: input image with shape (H, W, 3). 35 | """ 36 | pass 37 | 38 | def get_smoothed_mask(self, input_image, stdev_spread=.2, nsamples=50): 39 | """Returns a mask that is smoothed with the SmoothGrad method. 40 | 41 | Args: 42 | input_image: input image with shape (H, W, 3). 43 | """ 44 | stdev = stdev_spread * (np.max(input_image) - np.min(input_image)) 45 | 46 | total_gradients = np.zeros_like(input_image) 47 | for i in range(nsamples): 48 | noise = np.random.normal(0, stdev, input_image.shape) 49 | x_value_plus_noise = input_image + noise 50 | 51 | total_gradients += self.get_mask(x_value_plus_noise) 52 | 53 | return total_gradients / nsamples 54 | 55 | class GradientSaliency(SaliencyMask): 56 | r"""A SaliencyMask class that computes saliency masks with a gradient.""" 57 | 58 | def __init__(self, model, output_index=0): 59 | # Define the function to compute the gradient 60 | input_tensors = [model.input, # placeholder for input image tensor 61 | K.learning_phase(), # placeholder for mode (train or test) tense 62 | ] 63 | gradients = model.optimizer.get_gradients(model.output[0][output_index], model.input) 64 | self.compute_gradients = K.function(inputs=input_tensors, outputs=gradients) 65 | 66 | def get_mask(self, input_image): 67 | """Returns a vanilla gradient mask. 68 | 69 | Args: 70 | input_image: input image with shape (H, W, 3). 71 | """ 72 | 73 | # Execute the function to compute the gradient 74 | x_value = np.expand_dims(input_image, axis=0) 75 | gradients = self.compute_gradients([x_value, 0])[0][0] 76 | 77 | return gradients 78 | -------------------------------------------------------------------------------- /deepviz/guided_backprop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilites to computed GuidedBackprop SaliencyMasks""" 16 | 17 | from saliency import SaliencyMask 18 | import numpy as np 19 | import tensorflow as tf 20 | import keras.backend as K 21 | from keras.models import load_model 22 | 23 | class GuidedBackprop(SaliencyMask): 24 | """A SaliencyMask class that computes saliency masks with GuidedBackProp. 25 | 26 | This implementation copies the TensorFlow graph to a new graph with the ReLU 27 | gradient overwritten as in the paper: 28 | https://arxiv.org/abs/1412.6806 29 | """ 30 | 31 | GuidedReluRegistered = False 32 | 33 | def __init__(self, model, output_index=0, custom_loss=None): 34 | """Constructs a GuidedBackprop SaliencyMask.""" 35 | 36 | if GuidedBackprop.GuidedReluRegistered is False: 37 | @tf.RegisterGradient("GuidedRelu") 38 | def _GuidedReluGrad(op, grad): 39 | gate_g = tf.cast(grad > 0, "float32") 40 | gate_y = tf.cast(op.outputs[0] > 0, "float32") 41 | return gate_y * gate_g * grad 42 | GuidedBackprop.GuidedReluRegistered = True 43 | 44 | """ 45 | Create a dummy session to set the learning phase to 0 (test mode in keras) without 46 | inteferring with the session in the original keras model. This is a workaround 47 | for the problem that tf.gradients returns error with keras models that contains 48 | Dropout or BatchNormalization. 49 | 50 | Basic Idea: save keras model => create new keras model with learning phase set to 0 => save 51 | the tensorflow graph => create new tensorflow graph with ReLU replaced by GuiededReLU. 52 | """ 53 | model.save('/tmp/gb_keras.h5') 54 | with tf.Graph().as_default(): 55 | with tf.Session().as_default(): 56 | K.set_learning_phase(0) 57 | load_model('/tmp/gb_keras.h5', custom_objects={"custom_loss":custom_loss}) 58 | session = K.get_session() 59 | tf.train.export_meta_graph() 60 | 61 | saver = tf.train.Saver() 62 | saver.save(session, '/tmp/guided_backprop_ckpt') 63 | 64 | self.guided_graph = tf.Graph() 65 | with self.guided_graph.as_default(): 66 | self.guided_sess = tf.Session(graph = self.guided_graph) 67 | 68 | with self.guided_graph.gradient_override_map({'Relu': 'GuidedRelu'}): 69 | saver = tf.train.import_meta_graph('/tmp/guided_backprop_ckpt.meta') 70 | saver.restore(self.guided_sess, '/tmp/guided_backprop_ckpt') 71 | 72 | self.imported_y = self.guided_graph.get_tensor_by_name(model.output.name)[0][output_index] 73 | self.imported_x = self.guided_graph.get_tensor_by_name(model.input.name) 74 | 75 | self.guided_grads_node = tf.gradients(self.imported_y, self.imported_x) 76 | 77 | def get_mask(self, input_image): 78 | """Returns a GuidedBackprop mask.""" 79 | x_value = np.expand_dims(input_image, axis=0) 80 | guided_feed_dict = {} 81 | guided_feed_dict[self.imported_x] = x_value 82 | 83 | gradients = self.guided_sess.run(self.guided_grads_node, feed_dict = guided_feed_dict)[0][0] 84 | 85 | return gradients -------------------------------------------------------------------------------- /subjective_training.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from src import load_data, preprocessing, cross_validator, save_resutls 4 | 5 | import matplotlib.pyplot as plt 6 | from os import listdir 7 | from os.path import isfile, join 8 | from utils import plot_confusion_matrix 9 | import sys 10 | import pdb 11 | #%% 12 | 13 | eeglabel = ['Fz','Cz','Pz','Oz','Iz','Fp1','Fp2','F3','F4','F7','F8','T7','T8','C3','C4','P3','P4','O9','O10','P7', 14 | 'P8','FC1','FC2','CP1','CP2','PO3','PO4','PO7','PO8','LMAST'] 15 | 16 | subjects = [sys.argv[1]] 17 | #subjects = ['kq84'] 18 | 19 | batch_size = 64 20 | lr = 0.001 21 | early_stopping=True 22 | patience=20 23 | epochs=500 24 | model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'} 25 | 26 | 27 | 28 | datasets = ['50_avg','250'] 29 | 30 | 31 | for dataset in datasets: 32 | model_names = ['deep_subjective_branched_'+str(dataset)+'_thesis1', 33 | 'deep_subjective_eegnet_'+str(dataset)+'_thesis1', 34 | 'deep_subjective_cnn_'+str(dataset)+'_thesis1', 35 | 'lda_subjective_shrinkage_'+str(dataset)+'_thesis1', 36 | 'lda_subjective_'+str(dataset)+'_thesis1', 37 | 'deep_subjective_branched_no_bn_'+str(dataset)+'_thesis1', 38 | 'deep_subjective_branched_no_dropout_'+str(dataset)+'_thesis1', 39 | 'deep_subjective_branched_no_branched_'+str(dataset)+'_thesis1', 40 | 'deep_subjective_branched_no_deep_'+str(dataset)+'_thesis1', 41 | 'deep_subjective_branched_relu_'+str(dataset)+'_thesis1', 42 | 'deep_subjective_branched_elu_'+str(dataset)+'_thesis1'] 43 | 44 | model_configs = [{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 45 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 46 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 47 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 48 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 49 | {'bn':False, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 50 | {'bn':True, 'dropout':False, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 51 | {'bn':True, 'dropout':True, 'branched':False, 'deep':True, 'nonlinear':'tanh'}, 52 | {'bn':True, 'dropout':True, 'branched':True, 'deep':False, 'nonlinear':'tanh'}, 53 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'relu'}, 54 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'elu'}] 55 | 56 | 57 | for model_name, model_config in zip(model_names, model_configs): 58 | for subject in subjects: 59 | 60 | print 'working on subject : ', subject 61 | x, y, t, tr = load_data(subject, channels=range(29), frequency=dataset) 62 | x = preprocessing(x, frequency=dataset) 63 | metrics, histories, cnf_matrices = cross_validator((x, y, t, tr),subject, 64 | n_splits=5, epochs=epochs, 65 | batch_size=batch_size, 66 | lr=lr, 67 | model_name=model_name, 68 | model_config=model_config, 69 | early_stopping=early_stopping, 70 | patience=patience) 71 | 72 | super_final_results = save_resutls(metrics, histories, subject, suffix=model_name, early_stopping=early_stopping, 73 | patience=patience) 74 | 75 | print 'DA:', model_name,'_', subject, ' is:', super_final_results[0]['val_recognition_acc']['mean'] 76 | print 'BA:', model_name,'_', subject, ' is:', super_final_results[0]['val_balanced_acc']['mean'] 77 | print 'Recall:', model_name,'_', subject, ' is:', super_final_results[0]['val_recalls']['mean'] 78 | print 'precision:', model_name,'_', subject, ' is:', super_final_results[0]['val_precisions']['mean'] 79 | -------------------------------------------------------------------------------- /visualizations.py: -------------------------------------------------------------------------------- 1 | from src import collect_data_intersubjective, transform 2 | import matplotlib.pyplot as plt 3 | plt.switch_backend('agg') 4 | from vis.visualization import visualize_activation 5 | from vis.utils import utils 6 | from keras import activations 7 | from keras.models import load_model 8 | import numpy as np 9 | from os import listdir 10 | from os.path import isfile, join 11 | from vis.visualization import get_num_filters 12 | from deepviz.saliency import GradientSaliency 13 | from deepviz.guided_backprop import GuidedBackprop 14 | 15 | ## this script visualize the saliency maps and activation maximization maps for thr trained models. 16 | 17 | # define a function to show and/or save the maps. 18 | def show_save_image(matrix, prefix='', actions=['save']): 19 | fig, ax = plt.subplots(1,1) 20 | fig.set_size_inches(20,12) 21 | if prefix.startswith('saliency'): 22 | cax = ax.imshow(np.swapaxes(matrix, 0, 1)[:,:,0], cmap='jet', vmin=-0.015, vmax=0.015) 23 | elif prefix.startswith('activation'): 24 | cax = ax.imshow(np.swapaxes(matrix, 0, 1)[:,:,0], cmap='jet', vmin=-4, vmax=4) 25 | ax.set_yticklabels(eeglabel) 26 | ax.set_yticks(range(0,len(eeglabel))) 27 | ax.set_title(prefix) 28 | ax.set_xticklabels(range(-200,801, 100)) 29 | cbar = fig.colorbar(cax) 30 | ax.set_aspect('auto') 31 | ax.set_xlabel('time') 32 | ax.set_ylabel('channel') 33 | if 'save' in actions: 34 | fig.savefig('plots/'+prefix+'.png') 35 | if 'show' in actions: 36 | plt.show() 37 | plt.close() 38 | 39 | 40 | 41 | 42 | subjects = [name for name in listdir("./data/50/subjects/")] 43 | path = './models/subjects/' 44 | # defining channels names 45 | eeglabel = ['Fz','Cz','Pz','Oz','Iz','Fp1','Fp2','F3','F4','F7','F8','T7','T8','C3','C4','P3','P4','O9','O10','P7', 46 | 'P8','FC1','FC2','CP1','CP2','PO3','PO4','PO7','PO8','LMAST'] 47 | # defining the experiments that need to be visualized to get their models from the models directory. 48 | trial_names = ['intersubjective_eegnet', 'intersubjective_ft_10eegnet', 'intersubjective_ft_20eegnet', 'intersubjective_ft_30eegnet', 49 | 'intersubjective_ft_50eegnet', 'intersubjective_ft_70eegnet'] 50 | 51 | all_grads = np.empty((len(subjects),41,30,1)) 52 | 53 | for n, test_subject in enumerate(subjects): 54 | print 'WORKING ON', test_subject 55 | 56 | # loading the subject data 57 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = collect_data_intersubjective(subjects, test_subject) 58 | x_train, y_train, x_test, y_test = transform((x_train, y_train, x_test, y_test)) 59 | 60 | # calculating the minimum and maximum values for activation maximization calculations 61 | min_values = np.amin(np.min(x_train, axis=1), axis=1) 62 | max_values = np.amax(np.amax(x_train, axis=1), axis=1) 63 | 64 | min_value = np.mean(min_values) 65 | max_value = np.mean(max_values) 66 | 67 | # create the maps fpr each model/experiment 68 | for trial_name in trial_names: 69 | 70 | files = [f for f in listdir(join(path, test_subject)) if isfile(join(path, test_subject, f))] 71 | myfile = [file for file in files if trial_name in file ][0] 72 | 73 | model = load_model(join(path, test_subject, myfile)) 74 | # creating the saliency maps only for the models that are not finetuned. (just a personal choice) 75 | if '_ft_' not in trial_name: 76 | attended_examples = x_test[y_test==1] 77 | subj_grads = np.empty(attended_examples.shape+(1,)) 78 | vanilla = GradientSaliency(model) 79 | # creating a saliency map for each attended example and then averag them. 80 | for ex in range(attended_examples.shape[0]): 81 | example = attended_examples[ex] 82 | example = np.expand_dims(example, axis=2) 83 | grads = vanilla.get_smoothed_mask(example) 84 | subj_grads[ex] = grads 85 | subj_grads = np.mean(subj_grads, axis=0) 86 | all_grads[n] = subj_grads 87 | show_save_image(subj_grads, prefix='saliency_maps/'+test_subject+'_'+trial_name, actions=['save']) 88 | 89 | 90 | # computing the activation maximization maps 91 | layer_idx = -1 92 | filters = np.arange(get_num_filters(model.layers[layer_idx])) 93 | for filter_idx in filters: 94 | # changing the sigmoid activation to linear. 95 | model.layers[layer_idx].activation = activations.linear 96 | model = utils.apply_modifications(model) 97 | img = visualize_activation(model, layer_idx, filter_indices=filter_idx, input_range=(min_value,max_value)) 98 | show_save_image(img, prefix='activation_maximization/'+test_subject+'_'+trial_name, actions=['save']) 99 | 100 | 101 | 102 | # saving the numerical values of the saliency maps of all subjects as it will be used later to decide the most significant 103 | #channels 104 | 105 | np.save(open('plots/saliency_maps/all_grads_eegnet.npy', 'w'), all_grads) 106 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Mar 26 13:57:10 2018 5 | 6 | @author: amr 7 | """ 8 | from keras.models import Sequential, Model 9 | from keras.layers import Dense, LSTM, GRU, Input 10 | from keras.layers import Conv1D, Conv2D, MaxPooling2D, AveragePooling2D, AveragePooling1D, GlobalAveragePooling2D 11 | from keras.layers import Activation, Dropout, Flatten, Dense, Reshape, Merge, SeparableConv2D 12 | from keras.layers import PReLU 13 | from keras.layers.merge import average, concatenate, add 14 | from keras.layers.normalization import BatchNormalization 15 | from keras.layers.wrappers import TimeDistributed, Bidirectional 16 | from keras.layers.local import LocallyConnected1D 17 | from keras.optimizers import Adam, RMSprop, SGD 18 | from keras.callbacks import Callback, TensorBoard, EarlyStopping, LearningRateScheduler, ModelCheckpoint, ReduceLROnPlateau 19 | from keras.constraints import maxnorm 20 | from keras.utils import plot_model 21 | from keras import regularizers 22 | from keras.regularizers import l2, l1 23 | import keras.backend as K 24 | 25 | def branched2(data_shape, model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, f=1): 26 | timepoints = data_shape[1] 27 | channels = data_shape[2] 28 | reg = 0.01 29 | input_data = Input(shape=(timepoints, channels, 1)) 30 | 31 | 32 | spatial_conv = Conv2D(12, (1,channels), padding='valid', kernel_regularizer=l2(reg))(input_data) 33 | if model_config['bn']: 34 | spatial_conv = BatchNormalization()(spatial_conv) 35 | spatial_conv = Activation(model_config['nonlinear'])(spatial_conv) 36 | if model_config['dropout']: 37 | spatial_conv = Dropout(0.5)(spatial_conv) 38 | 39 | if model_config['branched']: 40 | branch1 = Conv2D(4, (21*f,1), padding='valid', kernel_regularizer=l2(reg))(spatial_conv) 41 | if model_config['bn']: 42 | branch1 = BatchNormalization()(branch1) 43 | branch1 = Activation(model_config['nonlinear'])(branch1) 44 | branch1 = MaxPooling2D(pool_size=(3,1), strides=(3,1))(branch1) 45 | if model_config['dropout']: 46 | branch1 = Dropout(0.5)(branch1) 47 | 48 | 49 | branch1 = Flatten()(branch1) 50 | 51 | branch2 = Conv2D(4, (5*f,1), padding='valid', dilation_rate=(1,1), kernel_regularizer=l2(reg))(spatial_conv) 52 | if model_config['bn']: 53 | branch2 = BatchNormalization()(branch2) 54 | branch2 = Activation(model_config['nonlinear'])(branch2) 55 | branch2 = MaxPooling2D(pool_size=(3,1), strides=(3,1))(branch2) 56 | if model_config['dropout']: 57 | branch2 = Dropout(0.5)(branch2) 58 | 59 | if model_config['deep']: 60 | branch2 = Conv2D(8, (5*f,1), padding='valid', dilation_rate=(1,1), kernel_regularizer=l2(reg))(branch2) 61 | if model_config['bn']: 62 | branch2 = BatchNormalization()(branch2) 63 | branch2 = Activation(model_config['nonlinear'])(branch2) 64 | if model_config['dropout']: 65 | branch2 = Dropout(0.5)(branch2) 66 | # 67 | branch2 = Conv2D(8, (5*f,1), padding='valid', kernel_regularizer=l2(reg))(branch2) 68 | if model_config['bn']: 69 | branch2 = BatchNormalization()(branch2) 70 | branch2 = Activation(model_config['nonlinear'])(branch2) 71 | branch2 = MaxPooling2D(pool_size=(2,1), strides=(2,1))(branch2) 72 | if model_config['dropout']: 73 | branch2 = Dropout(0.5)(branch2) 74 | # 75 | branch2 = Flatten()(branch2) 76 | 77 | if model_config['branched']: 78 | merged = concatenate([branch1, branch2]) 79 | dense = Dense(1, activation='sigmoid', kernel_regularizer=l2(reg))(merged) 80 | else: 81 | dense = Dense(1, activation='sigmoid', kernel_regularizer=l2(reg))(branch2) 82 | 83 | model = Model(inputs = [input_data], outputs=[dense]) 84 | 85 | 86 | return model 87 | 88 | 89 | 90 | def create_cnn(data_shape, f=1): 91 | 92 | timepoints = data_shape[1] 93 | channels = data_shape[2] 94 | kernel_size = 3*f 95 | model = Sequential() 96 | 97 | model.add(Conv2D(4, (kernel_size,3), activation='tanh', input_shape=(timepoints, channels, 1))) 98 | model.add(BatchNormalization()) 99 | model.add(MaxPooling2D((2,2))) 100 | model.add(Dropout(0.25)) 101 | 102 | 103 | model.add(Conv2D(8, (kernel_size,3), strides=(1,1), activation='tanh', padding='valid')) 104 | model.add(BatchNormalization()) 105 | model.add(MaxPooling2D((2,2))) 106 | # model.add(Dropout(0.25)) 107 | 108 | model.add(Conv2D(16, (kernel_size,3), strides=(1,1), activation='tanh', padding='valid')) 109 | model.add(BatchNormalization()) 110 | model.add(MaxPooling2D((2,2))) 111 | # model.add(Dropout(0.25)) 112 | 113 | model.add(Flatten()) 114 | 115 | model.add(Dense(1, activation='sigmoid')) 116 | 117 | return model 118 | 119 | def create_eegnet(data_shape, f=1): 120 | 121 | timepoints = data_shape[1] 122 | channels = data_shape[2] 123 | 124 | spatial_filters = 16 125 | model = Sequential() 126 | 127 | model.add(Conv2D(spatial_filters, (1,channels), activation='relu', 128 | kernel_regularizer=regularizers.l1_l2(0.0001), input_shape=(timepoints, channels, 1))) 129 | model.add(BatchNormalization()) 130 | model.add(Dropout(0.25)) 131 | model.add(Reshape((timepoints,spatial_filters,1))) 132 | 133 | model.add(Conv2D(4, (16*f,2), strides=(1,1), activation='relu', padding='same')) 134 | model.add(BatchNormalization()) 135 | model.add(MaxPooling2D((4,2))) 136 | model.add(Dropout(0.25)) 137 | 138 | model.add(Conv2D(4, (2*f,8), strides=(1,1), activation='relu', padding='same')) 139 | model.add(BatchNormalization()) 140 | model.add(MaxPooling2D((4,2))) 141 | model.add(Dropout(0.25)) 142 | 143 | model.add(Flatten()) 144 | model.add(Dense(1, activation='sigmoid')) 145 | return model 146 | -------------------------------------------------------------------------------- /intersubjective_training.py: -------------------------------------------------------------------------------- 1 | from src import collect_data_intersubjective, resample_transform, intersubjective_training, save_resutls, intersubjective_shallow, finetune 2 | import os 3 | from os.path import isfile, join 4 | from keras.models import load_model, clone_model 5 | 6 | from keras.optimizers import SGD 7 | import sys 8 | #import csv 9 | import pdb 10 | #%% 11 | 12 | 13 | subjects = [name for name in os.listdir("./data/50/subjects/")] 14 | #test_subjects = ['ab82'] 15 | test_subjects = [sys.argv[1]] 16 | 17 | batch_size = 64 18 | lr = 0.001 19 | early_stopping = True 20 | epochs = 500 21 | patience = 20 22 | 23 | model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'} 24 | 25 | ft = False 26 | ft_mode = 'all' 27 | ft_trials = [10, 20, 30, 40, 50, 60, 70] 28 | # 29 | 30 | 31 | datasets = ['50_avg', '250'] 32 | 33 | 34 | for dataset in datasets: 35 | model_names = ['deep_intersubjective_branched_'+str(dataset)+'_thesis1', 36 | 'deep_intersubjective_eegnet_'+str(dataset)+'_thesis1', 37 | 'deep_intersubjective_cnn_'+str(dataset)+'_thesis1', 38 | 'lda_intersubjective_shrinkage_'+str(dataset)+'_thesis1', 39 | 'lda_intersubjective_'+str(dataset)+'_thesis1', 40 | 'deep_intersubjective_branched_no_bn_'+str(dataset)+'_thesis1', 41 | 'deep_intersubjective_branched_no_dropout_'+str(dataset)+'_thesis1', 42 | 'deep_intersubjective_branched_no_branched_'+str(dataset)+'_thesis1', 43 | 'deep_intersubjective_branched_no_deep_'+str(dataset)+'_thesis1', 44 | 'deep_intersubjective_branched_relu_'+str(dataset)+'_thesis1', 45 | 'deep_intersubjective_branched_elu_'+str(dataset)+'_thesis1'] 46 | 47 | model_configs = [{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 48 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 49 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 50 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 51 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 52 | {'bn':False, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 53 | {'bn':True, 'dropout':False, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, 54 | {'bn':True, 'dropout':True, 'branched':False, 'deep':True, 'nonlinear':'tanh'}, 55 | {'bn':True, 'dropout':True, 'branched':True, 'deep':False, 'nonlinear':'tanh'}, 56 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'relu'}, 57 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'elu'}] 58 | 59 | 60 | for model_name, model_config in zip(model_names, model_configs): 61 | 62 | 63 | for test_subject in test_subjects: 64 | print 'working on subject', test_subject 65 | #test_subject = subject 66 | #collecting the data 67 | print 'Collecting data ...' 68 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = collect_data_intersubjective(subjects, 69 | test_subject, 70 | mode='eeg', 71 | channels=range(29), 72 | frequency=dataset) 73 | 74 | 75 | #Train 76 | print "Training ..." 77 | if model_name.startswith('deep'): 78 | metrics, history, cnf_matrix = intersubjective_training((x_train, y_train, x_test, y_test, o_t_test, o_tr_test), 79 | model_name, 80 | test_subject, 81 | epochs=epochs, 82 | lr=lr, 83 | batch_size=batch_size, 84 | model_config=model_config, 85 | early_stopping=early_stopping, 86 | patience=patience) 87 | super_final_results = save_resutls([metrics], [history], test_subject, 88 | suffix=model_name, 89 | early_stopping=early_stopping, 90 | patience=patience) 91 | else: 92 | metrics, history, cnf_matrix, clf = intersubjective_shallow((x_train, y_train, x_test, y_test, o_t_test, o_tr_test), 93 | model_name) 94 | super_final_results = save_resutls(metrics, history, test_subject, suffix=model_name, clf=clf) 95 | 96 | print 'DA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean'] 97 | print 'BA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean'] 98 | print 'Recall:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean'] 99 | print 'precision:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean'] 100 | 101 | if model_name.startswith('deep_intersubjective_branched') and ft: 102 | model=clone_model(history.model) 103 | weights = history.model.get_weights() 104 | 105 | for i in ft_trials: 106 | model_name_modified = 'deep_intersubjective_branched_ft_'+str(i)+'_trials_'+str(dataset)+'_thesis2' 107 | model_name_modified = model_name+'_ft_'+str(i)+'_trials' 108 | model.set_weights(weights) 109 | metrics, history, cnf_matrix = finetune(model, 110 | (x_test, y_test, o_t_test, o_tr_test), 111 | model_name_modified, 112 | test_subject, 113 | epochs=epochs, 114 | train_trials=i, 115 | mode=ft_mode, 116 | early_stopping=early_stopping, 117 | patience=patience) 118 | 119 | super_final_results = save_resutls([metrics], [history], test_subject, 120 | suffix=model_name_modified, 121 | early_stopping=early_stopping, 122 | patience=patience) 123 | print 'DA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean'] 124 | print 'BA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean'] 125 | print 'Recall:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean'] 126 | print 'precision:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean'] 127 | -------------------------------------------------------------------------------- /src.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import tensorflow as tf 4 | import random as rn 5 | np.random.seed(42) 6 | rn.seed(12345) 7 | tf.set_random_seed(1234) 8 | from models import * 9 | import matplotlib.pyplot as plt 10 | from imblearn.under_sampling import RandomUnderSampler 11 | from imblearn.over_sampling import RandomOverSampler 12 | from sklearn.model_selection import train_test_split, LeavePGroupsOut, GroupKFold, StratifiedKFold 13 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 14 | from sklearn.utils.class_weight import compute_class_weight 15 | from sklearn.model_selection import StratifiedKFold, train_test_split, StratifiedShuffleSplit 16 | from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix, accuracy_score 17 | import itertools 18 | from utils import plot_confusion_matrix 19 | import pdb 20 | import os 21 | from os import listdir 22 | from os.path import isfile, join 23 | import time 24 | import math 25 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 26 | from sklearn import svm 27 | from sklearn.externals import joblib 28 | from keras.models import load_model 29 | 30 | 31 | def load_data(subject, mode='eeg', channels=[], frequency=50): 32 | data_path = 'data/'+str(frequency)+'/subjects/' 33 | if mode=='eeg': 34 | x = np.load(open(data_path+subject+'/segments.npy')) 35 | elif mode=='meg': 36 | x = np.load(open(data_path+subject+'/meg_segments.npy')) 37 | else: 38 | print 'Wrong mode. you can only choose eeg or meg' 39 | y = np.load(open(data_path+subject+'/labels.npy')) 40 | t = np.load(open(data_path+subject+'/target.npy')) 41 | tr = np.load(open(data_path+subject+'/triggers.npy')) 42 | temp_t = [] 43 | for i in t: 44 | trial_index = np.ones((60,))*i 45 | temp_t.extend(np.ones((60,))*i) 46 | t = np.array(temp_t).astype(int) 47 | 48 | if len(channels): 49 | x = x[:,channels,:] 50 | 51 | # indexes = np.random.permutation(len(tr)) 52 | # tr = tr[indexes] 53 | return x, y, t, tr 54 | def preprocessing(x, frequency='50_avg'): 55 | x = np.swapaxes(x,1,2) 56 | ## baseline correction 57 | if frequency == '50_avg': 58 | bl = 5 59 | elif frequency=='250': 60 | bl = 25 61 | # bl = int(frequency/10) 62 | corrected = np.empty_like(x) 63 | for i in range(x.shape[0]): 64 | baselines = np.mean(x[i,0:bl,:], axis=0) 65 | corrected[i] = x[i] - baselines 66 | x = corrected 67 | return x 68 | 69 | def precision(y_true, y_pred): 70 | return precision_score(y_true, y_pred) 71 | def recall(y_true, y_pred): 72 | return recall_score(y_true, y_pred) 73 | def f1(y_true, y_pred): 74 | return f1_score(y_true, y_pred) 75 | def auc(y_true, y_pred): 76 | return roc_auc_score(y_true, y_pred) 77 | def balanced_accuracy(y_true, y_pred): 78 | tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() 79 | recall_p = float(tp) / (tp + fn) 80 | recall_n = float(tn) / (tn + fp) 81 | return (recall_p + recall_n) /2.0 82 | 83 | #def recognition_accuracy(probs): 84 | # trials = int(len(probs)/60) 85 | # n_correct = 0 86 | # matrix = [] 87 | # for i in range(trials): 88 | # classes = np.zeros((12,)) 89 | # for k in range(i*60,i*60+60): 90 | # 91 | # classes[tr_test[k]-1] += probs[k] 92 | # choosen = np.argmax(classes)+1 93 | # if choosen == t_test[i*60]: 94 | # n_correct+=1 95 | # confidence = classes/float(np.sum(classes)) 96 | # row = np.array([confidence, t_test[i*60]]) 97 | # matrix.append(row) 98 | # return np.array([n_correct/float(trials), np.array(matrix)]) 99 | def recognition_accuracy(probs): 100 | # pdb.set_trace() 101 | trials = int(len(probs)/60) 102 | n_correct = np.zeros((5,)) 103 | matrix = [] 104 | for i in range(trials): 105 | classes = [[] for s in range(12)] 106 | for k in range(i*60,i*60+60): 107 | 108 | classes[tr_test[k]-1].append(probs[k]) 109 | classes = np.array(classes) 110 | # pdb.set_trace() 111 | if classes.ndim > 1: 112 | classes = np.cumsum(classes, axis=1) 113 | choosen = np.argmax(classes, axis=0)+1 114 | choices = [] 115 | for j in range(5): 116 | if choosen[j] == t_test[i*60]: 117 | n_correct[j]+=1 118 | choices.append((t_test[i*60], choosen[j])) 119 | # confidence = classes/float(np.sum(classes)) 120 | # row = np.array([confidence, t_test[i*60]]) 121 | # matrix.append(row) 122 | matrix.append(choices) 123 | return np.array([n_correct/float(trials), np.array(matrix)]) 124 | 125 | def bitperminute(p, n, t): 126 | p[p==1] = 1-np.finfo(float).eps 127 | B = np.log2(n)+p*np.log2(p)+(1-p)*np.log2((1-p)/(n-1).astype(float)) 128 | return B*(float(60)/t) 129 | 130 | class Metrics(Callback): 131 | def on_train_begin(self, logs={}): 132 | self.val_precisions = [] 133 | self.val_recalls = [] 134 | self.val_f1s = [] 135 | self.val_aucs = [] 136 | self.val_balanced_acc = [] 137 | self.val_recognition_acc = [] 138 | self.val_bpm = [] 139 | self.test_acc = [] 140 | self.test_loss = [] 141 | def on_epoch_end(self, epoch, logs={}): 142 | # probs = self.model.predict(self.validation_data[0]) 143 | # y_pred = np.round(probs) 144 | # y_true = self.validation_data[1] 145 | probs = self.model.predict(x_test) 146 | probs = probs.ravel() 147 | y_pred = np.round(probs) 148 | y_true = y_test 149 | self.test_acc.append(accuracy_score(y_true, y_pred)) 150 | self.test_loss.append(self.model.evaluate(x_test, y_test, batch_size = 64, verbose=0)[0]) 151 | self.val_precisions.append(precision(y_true, y_pred)) 152 | self.val_recalls.append(recall(y_true, y_pred)) 153 | self.val_f1s.append(f1(y_true, y_pred)) 154 | self.val_aucs.append(auc(y_true, probs)) 155 | self.val_balanced_acc.append(balanced_accuracy(y_true, y_pred)) 156 | self.val_recognition_acc.append(recognition_accuracy(probs)) 157 | self.val_bpm.append(bitperminute(self.val_recognition_acc[-1][0], np.ones((5,))*12,np.arange(2,12,2))) 158 | return 159 | 160 | 161 | def step_decay(epoch): 162 | initial_lrate = 0.001 163 | drop = 0.5 164 | epochs_drop = 10.0 165 | lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop)) 166 | # if epoch: 167 | # lrate = initial_lrate/np.sqrt(epoch) 168 | # else: 169 | # return initial_lrate 170 | return lrate 171 | 172 | def cv_splitter(x, n_splits=5): 173 | 174 | n_segments = x.shape[0] 175 | n_trials = int(n_segments/60) 176 | groups = [] 177 | for i in range(n_trials): 178 | trial_index = np.ones((60,))*i 179 | groups.extend(trial_index) 180 | groups = np.array(groups) 181 | 182 | window = int(np.round(n_trials/float(n_splits))) 183 | 184 | intervals=[] 185 | for i in range(0,n_trials,window): 186 | intervals.append(range(i,np.minimum(i+window, n_trials))) 187 | 188 | if len(intervals[-1]) 4: 544 | x_train, y_train, x_test, y_test, x_valid, y_valid = data 545 | else: 546 | x_train, y_train, x_test, y_test = data 547 | 548 | 549 | # standarization of the data 550 | # computing the mean and std on the training data 551 | # scalar = StandardScaler(with_mean=False) 552 | # stds = [] 553 | # trials_no = x_train.shape[0] 554 | # for i in range(trials_no): 555 | # scalar.fit(x_train[i]) 556 | # std = scalar.scale_ 557 | # stds.append(std) 558 | # 559 | # scalar.scale_ = np.mean(stds, axis=0) 560 | # normalized_x_train = np.empty_like(x_train) 561 | # for i in range(trials_no): 562 | # temp = scalar.transform(x_train[i]) 563 | # normalized_x_train[i] = temp 564 | # 565 | # # transforming the test data 566 | # normalized_x_test = np.empty_like(x_test) 567 | # trials_no = x_test.shape[0] 568 | # for i in range(trials_no): 569 | # temp = scalar.transform(x_test[i]) 570 | # normalized_x_test[i] = temp 571 | # 572 | # x_train = normalized_x_train 573 | # x_test = normalized_x_test 574 | 575 | 576 | scalar = StandardScaler(with_mean=True) 577 | scalar.fit(x_train.reshape(x_train.shape[0],-1)) 578 | normalized_x_train = scalar.transform(x_train.reshape(x_train.shape[0], -1)).reshape(x_train.shape) 579 | normalized_x_test = scalar.transform(x_test.reshape(x_test.shape[0], -1)).reshape(x_test.shape) 580 | if len(data)>4: 581 | normalized_x_valid = scalar.transform(x_valid.reshape(x_valid.shape[0], -1)).reshape(x_valid.shape) 582 | x_train = normalized_x_train 583 | x_test = normalized_x_test 584 | if len(data) > 4: 585 | x_valid = normalized_x_valid 586 | 587 | if resample: 588 | #resampling the data 589 | n_samples, timepoints, channels = x_train.shape 590 | x_train = np.reshape(x_train, (n_samples, timepoints * channels)) 591 | ros = RandomOverSampler(random_state=0) 592 | x_res, y_res = ros.fit_sample(x_train, y_train) 593 | x_train = np.reshape(x_res, (x_res.shape[0], timepoints, channels)) 594 | y_train = y_res 595 | 596 | x_train = np.expand_dims(x_train, axis=3) 597 | x_test = np.expand_dims(x_test, axis=3) 598 | if len(data) > 4: 599 | x_valid = np.expand_dims(x_valid, axis=3) 600 | return x_train, y_train, x_test, y_test, x_valid, y_valid 601 | return x_train, y_train, x_test, y_test 602 | 603 | def transform(data): 604 | 605 | x_train, y_train, x_test, y_test = data 606 | 607 | 608 | # standarization of the data 609 | # computing the mean and std on the training data 610 | scalar = StandardScaler(with_mean=False) 611 | stds = [] 612 | trials_no = x_train.shape[0] 613 | for i in range(trials_no): 614 | scalar.fit(x_train[i]) 615 | std = scalar.scale_ 616 | stds.append(std) 617 | 618 | scalar.scale_ = np.mean(stds, axis=0) 619 | normalized_x_train = np.empty_like(x_train) 620 | for i in range(trials_no): 621 | temp = scalar.transform(x_train[i]) 622 | normalized_x_train[i] = temp 623 | 624 | # transforming the test data 625 | normalized_x_test = np.empty_like(x_test) 626 | trials_no = x_test.shape[0] 627 | for i in range(trials_no): 628 | temp = scalar.transform(x_test[i]) 629 | normalized_x_test[i] = temp 630 | 631 | x_train = normalized_x_train 632 | x_test = normalized_x_test 633 | 634 | return x_train, y_train, x_test, y_test 635 | def intersubjective_shallow(data, model_name): 636 | 637 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = data 638 | 639 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test), resample=False) 640 | 641 | global t_test 642 | t_test = o_t_test 643 | global tr_test 644 | tr_test = o_tr_test 645 | x_train = x_train.reshape(x_train.shape[0], -1) 646 | x_test = x_test.reshape(x_test.shape[0], -1) 647 | m = ['acc', 'val_acc', 'val_precisions', 'val_recalls', 'val_f1s', 'val_aucs', 648 | 'val_balanced_acc', 'val_recognition_acc', 'val_bpm'] 649 | metrics = {key: [] for key in m } 650 | history = False 651 | if 'svm' in model_name: 652 | clf = svm.LinearSVC(random_state = 0) 653 | elif 'lda' in model_name: 654 | if 'shrinkage' in model_name: 655 | clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto') 656 | else: 657 | clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None) 658 | clf.fit(x_train, y_train) 659 | y_predict = clf.predict(x_test) 660 | probs = clf.decision_function(x_test) 661 | metrics['acc'].append(clf.score(x_train, y_train)) 662 | metrics = compute_metrics(metrics, probs, y_predict, y_test) 663 | cnf_matrix = confusion_matrix(y_test, y_predict) 664 | 665 | return metrics, history, cnf_matrix, clf 666 | 667 | def intersubjective_training(data,model_name, subject, epochs=5, lr=0.001, 668 | batch_size=128, 669 | model_config={'bn':True, 'dropout':True, 'branched':True, 'nonlinear':'tanh'}, 670 | early_stopping=True, patience=10): 671 | 672 | global y_test 673 | x_tv, y_tv, x_test, y_test, o_t_test, o_tr_test = data 674 | 675 | if early_stopping: 676 | # pdb.set_trace() 677 | x_train, x_valid, y_train, y_valid = train_test_split(x_tv, y_tv, 678 | stratify=y_tv, 679 | test_size=0.2) 680 | global x_test 681 | x_train, y_train, x_test, y_test, x_valid, y_valid = resample_transform((x_train, y_train, x_test, y_test, x_valid, y_valid)) 682 | else: 683 | x_train = x_tv 684 | y_train = y_tv 685 | global x_test 686 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test)) 687 | 688 | global t_test 689 | t_test = o_t_test 690 | global tr_test 691 | tr_test = o_tr_test 692 | 693 | if 'branched' in model_name: 694 | if '250' in model_name: 695 | model = branched2(x_train.shape, model_config=model_config, f=5) 696 | else: 697 | model = branched2(x_train.shape, model_config=model_config, f=1) 698 | elif 'eegnet' in model_name: 699 | if '250' in model_name: 700 | model = create_eegnet(x_train.shape, f=4) 701 | else: 702 | model = create_eegnet(x_train.shape, f=1) 703 | elif 'cnn' in model_name: 704 | if '250' in model_name: 705 | model = create_cnn(x_train.shape, f=5) 706 | else: 707 | model = create_cnn(x_train.shape, f=1) 708 | 709 | lrate = LearningRateScheduler(step_decay) 710 | adam = Adam(lr=lr) 711 | # lrate = LearningRateScheduler(step_decay) 712 | model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy']) 713 | 714 | m = Metrics() 715 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 716 | patience=int(patience/2), min_lr=0) 717 | early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=0, mode='auto') 718 | mod_path = './models/subjects/'+subject 719 | timestr = time.strftime("%Y%m%d-%H%M") 720 | 721 | 722 | checkpointer = ModelCheckpoint(filepath=mod_path+'/best_'+model_name+'_'+timestr, 723 | monitor='val_loss', verbose=1, save_best_only=True) 724 | if early_stopping: 725 | history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, shuffle=True, verbose=2, 726 | validation_data=(x_valid, y_valid), callbacks=[m, early_stop, checkpointer, reduce_lr]) 727 | else: 728 | history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, shuffle=True, verbose=2, 729 | validation_data=(x_test, y_test), callbacks=[m]) 730 | 731 | 732 | 733 | 734 | probabilities = model.predict(x_test, batch_size=batch_size, verbose=0) 735 | y_predict = [(round(k)) for k in probabilities] 736 | 737 | cnf_matrix = confusion_matrix(y_test, y_predict) 738 | 739 | return m, history, cnf_matrix 740 | 741 | 742 | 743 | def finetune(model, data, model_name, subject, epochs=10, train_trials=40, mode='all', early_stopping=True, patience=10): 744 | 745 | for layer in model.layers[:26]: 746 | if mode=='all': 747 | layer.trainable = True 748 | elif mode=='top': 749 | layer.trainable = False 750 | else: 751 | print 'wrong keyword argument' 752 | return 753 | # print model.summary() 754 | 755 | opt = SGD(lr=1e-4, momentum=0.9) 756 | 757 | model.compile(loss='binary_crossentropy', optimizer=opt , metrics=['accuracy']) 758 | 759 | x_test, y_test, o_t_test, o_tr_test = data 760 | 761 | segments = train_trials * 60 762 | 763 | x_tv = x_test[0:segments] 764 | y_tv = y_test[0:segments] 765 | 766 | if early_stopping: 767 | x_train, x_valid, y_train, y_valid = train_test_split(x_tv, y_tv, 768 | stratify=y_tv, 769 | test_size=0.2) 770 | x_test = x_test[segments:] 771 | global y_test 772 | y_test = y_test[segments:] 773 | global x_test 774 | x_train, y_train, x_test, y_test, x_valid, y_valid = resample_transform((x_train, y_train, x_test, y_test, x_valid, y_valid)) 775 | else: 776 | x_train = x_tv 777 | y_train = y_tv 778 | x_test = x_test[segments:] 779 | global y_test 780 | y_test = y_test[segments:] 781 | global x_test 782 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test)) 783 | 784 | # x_test = x_test[segments:] 785 | # y_test = y_test[segments:] 786 | 787 | 788 | 789 | global t_test 790 | t_test = o_t_test[segments:] 791 | global tr_test 792 | tr_test = o_tr_test[segments:] 793 | # #resampling the data 794 | # n_samples, timepoints, channels, z = x_train.shape 795 | # x_train = np.reshape(x_train, (n_samples, timepoints * channels)) 796 | # ros = RandomOverSampler(random_state=0) 797 | # x_res, y_res = ros.fit_sample(x_train, y_train) 798 | # x_train = np.reshape(x_res, (x_res.shape[0], timepoints, channels)) 799 | # y_train = y_res 800 | # 801 | # x_train = np.expand_dims(x_train, axis=3) 802 | # x_test = np.expand_dims(x_test, axis=3) 803 | 804 | early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=0, mode='auto') 805 | 806 | 807 | m = Metrics() 808 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 809 | patience=int(patience/2), min_lr=0) 810 | mod_path = './models/subjects/'+subject 811 | timestr = time.strftime("%Y%m%d-%H%M") 812 | 813 | checkpointer = ModelCheckpoint(filepath=mod_path+'/best_'+model_name+'_'+timestr, 814 | monitor='val_loss', verbose=1, save_best_only=True) 815 | if early_stopping: 816 | history = model.fit(x_train, y_train, batch_size=64, epochs=epochs, shuffle=True, verbose=2, 817 | validation_data=(x_valid, y_valid), callbacks=[m, early_stop, checkpointer, reduce_lr]) 818 | else: 819 | history = model.fit(x_train, y_train, batch_size=64, epochs=epochs, shuffle=True, verbose=2, 820 | validation_data=(x_test, y_test), callbacks=[m]) 821 | 822 | 823 | probabilities = model.predict(x_test, batch_size=128, verbose=0) 824 | y_predict = [(round(k)) for k in probabilities] 825 | 826 | cnf_matrix = confusion_matrix(y_test, y_predict) 827 | 828 | 829 | return m, history, cnf_matrix 830 | --------------------------------------------------------------------------------