├── deepviz
│   ├── __init__.pyc
│   ├── saliency.pyc
│   ├── images
│   │   ├── cat_dog.png
│   │   ├── doberman.png
│   │   ├── cat_dog_viz.png
│   │   └── doberman_viz.png
│   ├── guided_backprop.pyc
│   ├── visual_backprop.pyc
│   ├── utils.py
│   ├── README.md
│   ├── integrated_gradients.py
│   ├── visual_backprop.py
│   ├── saliency.py
│   └── guided_backprop.py
├── a7503a85-c8f9-44b0-8c92-c440ab2b91de
├── README.md
├── segmenting.py
├── subjective_training.py
├── visualizations.py
├── models.py
├── intersubjective_training.py
└── src.py
/deepviz/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/__init__.pyc
--------------------------------------------------------------------------------
/deepviz/saliency.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/saliency.pyc
--------------------------------------------------------------------------------
/deepviz/images/cat_dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/cat_dog.png
--------------------------------------------------------------------------------
/deepviz/guided_backprop.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/guided_backprop.pyc
--------------------------------------------------------------------------------
/deepviz/images/doberman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/doberman.png
--------------------------------------------------------------------------------
/deepviz/visual_backprop.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/visual_backprop.pyc
--------------------------------------------------------------------------------
/deepviz/images/cat_dog_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/cat_dog_viz.png
--------------------------------------------------------------------------------
/deepviz/images/doberman_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/deepviz/images/doberman_viz.png
--------------------------------------------------------------------------------
/a7503a85-c8f9-44b0-8c92-c440ab2b91de:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amr-farahat/BCI-P300/HEAD/a7503a85-c8f9-44b0-8c92-c440ab2b91de
--------------------------------------------------------------------------------
/deepviz/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import PIL.Image
3 | import matplotlib.pyplot as plt
4 |
5 | def show_image(image, grayscale=True, ax=None, title=''):
6 |     if ax is None:
7 |         plt.figure()
8 |     plt.axis('off')
9 |
10 |     if len(image.shape) == 2 or grayscale:
11 |         if len(image.shape) == 3:
12 |             # collapse the channel axis into a single intensity map
13 |             image = np.sum(np.abs(image), axis=2)
14 |
15 |         # saturate the top 1% of values so a few extreme pixels do not wash out the map
16 |         vmax = np.percentile(image, 99)
17 |         vmin = np.min(image)
18 |
19 |         plt.imshow(image, cmap=plt.cm.gray, vmin=vmin, vmax=vmax)
20 |         plt.title(title)
21 |     else:
22 |         # undo the zero-centering applied in load_image
23 |         image = image + 127.5
24 |         image = image.astype('uint8')
25 |
26 |         plt.imshow(image)
27 |         plt.title(title)
28 |
29 | def load_image(file_path):
30 |     im = PIL.Image.open(file_path)
31 |     im = np.asarray(im)
32 |
33 |     # center pixel values around zero
34 |     return im - 127.5
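35 |
36 | # Typical round trip (an illustrative sketch, not part of the original module):
37 | #   img = load_image('images/doberman.png')   # pixel values centered around zero
38 | #   show_image(img, grayscale=False, title='input')
39 | #   # masks returned by the deepviz saliency classes can be shown the same way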
--------------------------------------------------------------------------------
/deepviz/README.md:
--------------------------------------------------------------------------------
1 |
2 | This repository contains Keras implementations of various methods for understanding the predictions of a convolutional neural network. Implemented methods are:
3 |
4 | * Vanilla gradient [https://arxiv.org/abs/1312.6034]
5 | * Guided backprop [https://arxiv.org/abs/1412.6806]
6 | * Integrated gradient [https://arxiv.org/abs/1703.01365]
7 | * Visual backprop [https://arxiv.org/abs/1611.05418]
8 |
9 | Each of them is accompanied by the corresponding SmoothGrad version [https://arxiv.org/abs/1706.03825], which improves any baseline method by averaging its masks over many noisy copies of the input.
10 |
11 | Courtesy of https://github.com/tensorflow/saliency and https://github.com/mbojarski/VisualBackProp.
12 |
13 | # Examples
14 |
15 | * Dog
16 |
17 | ![doberman](images/doberman.png) ![doberman saliency](images/doberman_viz.png)
18 |
19 | * Dog and Cat
20 |
21 | ![dog and cat](images/cat_dog.png) ![dog and cat saliency](images/cat_dog_viz.png)
22 |
23 |
24 | # Usage
25 |
26 |     cd deep-viz-keras
27 |
28 | ```python
29 | from guided_backprop import GuidedBackprop
30 | from utils import *
31 | from keras.applications.vgg16 import VGG16
32 |
33 | # Load the pretrained VGG16 model and make the guided backprop operator
34 | vgg16_model = VGG16(weights='imagenet')
35 | vgg16_model.compile(loss='categorical_crossentropy', optimizer='adam')
36 | guided_bprop = GuidedBackprop(vgg16_model)
37 |
38 | # Load the image and compute the guided gradient
39 | image = load_image('/path/to/image')
40 | mask = guided_bprop.get_mask(image) # compute the gradients
41 | show_image(mask) # display the grayscale mask
42 | ```
43 |
44 | The `examples.ipynb` notebook contains demos of all implemented methods using the built-in VGG16 model of Keras.
45 |
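46 | Every mask class also exposes a SmoothGrad variant through `get_smoothed_mask`. A minimal sketch, reusing `guided_bprop` and `image` from the snippet above (the `stdev_spread` and `nsamples` values shown are the library defaults, not tuned settings):
47 |
48 | ```python
49 | # average masks over many noisy copies of the input (SmoothGrad)
50 | smooth_mask = guided_bprop.get_smoothed_mask(image, stdev_spread=0.2, nsamples=50)
51 | show_image(smooth_mask)
52 | ```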
--------------------------------------------------------------------------------
/deepviz/integrated_gradients.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities to compute an IntegratedGradients SaliencyMask."""
16 |
17 | import numpy as np
18 | from saliency import GradientSaliency
19 |
20 | class IntegratedGradients(GradientSaliency):
21 | """A SaliencyMask class that implements the integrated gradients method.
22 |
23 | https://arxiv.org/abs/1703.01365
24 | """
25 |
26 |     def get_mask(self, input_image, input_baseline=None, nsamples=100):
27 |         """Returns an integrated gradients mask."""
28 |         if input_baseline is None:
29 | input_baseline = np.zeros_like(input_image)
30 |
31 | assert input_baseline.shape == input_image.shape
32 |
33 | input_diff = input_image - input_baseline
34 |
35 | total_gradients = np.zeros_like(input_image)
36 |
37 | for alpha in np.linspace(0, 1, nsamples):
38 | input_step = input_baseline + alpha * input_diff
39 |             total_gradients += super(IntegratedGradients, self).get_mask(input_step)
40 |
41 |         return total_gradients * input_diff / nsamples # average over the path, rescaled by (input - baseline)
42 |
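43 | # A minimal usage sketch (illustrative, not part of the module; assumes a compiled
44 | # Keras model `model` and a preprocessed `image` array as in the repository README):
45 | #
46 | #   ig = IntegratedGradients(model)
47 | #   mask = ig.get_mask(image, nsamples=50) # more samples = finer Riemann sum
48 | #
49 | # The loop above approximates
50 | #   (image - baseline) * integral_0^1 grad F(baseline + alpha*(image - baseline)) d(alpha),
51 | # the integrated gradients attribution of https://arxiv.org/abs/1703.01365.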
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # P300 Brain Computer Interface (BCI)
2 | A brain-computer interface project in which I use wide and deep convolutional neural networks to decode the P300 component in EEG signals for a speller application. I also use saliency maps to visualize the features learned by the models; these maps are then quantified to reveal task-related brain dynamics.
3 |
4 |
5 |
6 |
7 | #### Related Publications
8 | [Abstract](http://doi.org/10.12751/nncn.bc2018.0092) and [Poster](http://bit.ly/2PdqhcG) presented at the Bernstein Conference for Computational Neuroscience in Berlin in September 2018.
9 |
10 | [Paper](https://doi.org/10.1088/1741-2552/ab3bb4): "Convolutional neural networks for decoding of covert attention focus and saliency maps for EEG feature visualization", Journal of Neural Engineering, 16(6), 2019.
11 |
12 | For citation:
13 |
14 | ```bibtex
15 | @Article{Farahat_2019,
16 |   Title = {Convolutional neural networks for decoding of covert attention focus and saliency maps for {EEG} feature visualization},
17 |   Author = {Amr Farahat and Christoph Reichert and Catherine M Sweeney-Reed and Hermann Hinrichs},
18 |   Journal = {Journal of Neural Engineering},
19 |   Year = {2019},
20 |   Month = {oct},
21 |   Number = {6},
22 |   Pages = {066010},
23 |   Volume = {16},
24 |   Doi = {10.1088/1741-2552/ab3bb4},
25 |   Publisher = {{IOP} Publishing},
26 |   Url = {https://doi.org/10.1088%2F1741-2552%2Fab3bb4}
27 | }
28 | ```
29 |
30 | [Preprint](https://www.biorxiv.org/content/10.1101/614784v1)
31 |
--------------------------------------------------------------------------------
/segmenting.py:
--------------------------------------------------------------------------------
1 | ## Reading the EEG data from matlab .mat files and segmenting them to create the training examples.
2 |
3 | from scipy.io import loadmat
4 | import numpy as np
5 | import os
6 |
7 | # create list of files. one file per subject data.
8 | path = './data/p300/resampled_50/'
9 | files = [f for f in os.listdir(path) if f.endswith('.mat')]
10 |
11 | for fname in files:
12 |
13 | print 'working on file', fname
14 | data = loadmat(path+fname)
15 |     sampling_rate = 5.086299945003593e+02/10 # original ~508.63 Hz recording, downsampled by a factor of 10 to ~50.86 Hz
16 |     # create variables and organize axes
17 | # EEG data
18 | eeg = data['eeg']
19 | eeg = np.swapaxes(eeg, 0,2)
20 | eeg = np.swapaxes(eeg, 1, 2)
21 | # the attended targets
22 | target = data['target']
23 | target = np.reshape(target, (target.shape[0],))
24 | # the stimuli
25 | trigger = data['trigger']
26 | trigger = np.swapaxes(trigger, 0,1)
27 | # number of trials
28 | trials = eeg.shape[0]
29 | # choosing the pre and post stimulus interval for each segment
30 | prev = int(round(sampling_rate*0.1))
31 | aft = int(round(sampling_rate*0.7))
32 | segments = np.empty(shape=(trials*60,30,prev+aft))
33 | labels = []
34 | triggers = []
35 |     # loop through all the trials of the subject and segment them; there are 60 segments per trial.
36 | for k in range(trials):
37 | indices = np.nonzero(trigger[k])
38 | trigs = trigger[k,indices]
39 | indices = np.round(indices[0]/10.0).astype(int)
40 | for j,i in enumerate(indices[0:60]):
41 | segment = eeg[k,:,i-prev:i+aft]
42 | segments[k*60+j] = segment
43 | if trigs[0,j] == target[k]:
44 | labels.append(1)
45 | else:
46 | labels.append(0)
47 | triggers.append(trigs[0,j])
48 | labels = np.array(labels)
49 | triggers = np.array(triggers)
50 |
51 | print segments.shape, labels.shape, triggers.shape, target.shape
52 |
53 |     # create a new directory for the subject and save the arrays.
54 | subject = os.path.splitext(fname)[0].split('_')[1]
55 | os.makedirs('./data/50/subjects/'+subject)
56 |     # save the arrays to files to be used again for modelling.
57 |     np.save('./data/50/subjects/'+subject+'/trials.npy', eeg)
58 |     np.save('./data/50/subjects/'+subject+'/segments.npy', segments)
59 |     np.save('./data/50/subjects/'+subject+'/labels.npy', labels)
60 |     np.save('./data/50/subjects/'+subject+'/triggers.npy', triggers)
61 |     np.save('./data/50/subjects/'+subject+'/target.npy', target)
62 |
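63 | # Note on segment sizes: at the effective sampling rate of ~50.86 Hz,
64 | # prev = round(50.86*0.1) = 5 and aft = round(50.86*0.7) = 36 samples, so each
65 | # segment spans 41 timepoints (-100 ms to +700 ms around stimulus onset) and the
66 | # saved `segments` array has shape (trials*60, 30, 41).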
--------------------------------------------------------------------------------
/deepviz/visual_backprop.py:
--------------------------------------------------------------------------------
1 | from saliency import SaliencyMask
2 | import numpy as np
3 | import keras.backend as K
4 | from keras.layers import Input, Conv2DTranspose
5 | from keras.models import Model
6 | from keras.initializers import Ones, Zeros
7 |
8 | class VisualBackprop(SaliencyMask):
9 | """A SaliencyMask class that computes saliency masks with VisualBackprop (https://arxiv.org/abs/1611.05418).
10 | """
11 |
12 | def __init__(self, model, output_index=0):
13 | """Constructs a VisualProp SaliencyMask."""
14 | inps = [model.input, K.learning_phase()] # input placeholder
15 | outs = [layer.output for layer in model.layers] # all layer outputs
16 | self.forward_pass = K.function(inps, outs) # evaluation function
17 |
18 | self.model = model
19 |
20 | def get_mask(self, input_image):
21 | """Returns a VisualBackprop mask."""
22 | x_value = np.expand_dims(input_image, axis=0)
23 |
24 | visual_bpr = None
25 | layer_outs = self.forward_pass([x_value, 0])
26 |
27 | for i in range(len(self.model.layers)-1, -1, -1):
28 | if 'Conv2D' in str(type(self.model.layers[i])):
29 | layer = np.mean(layer_outs[i], axis=3, keepdims=True)
30 | layer = layer - np.min(layer)
31 | layer = layer/(np.max(layer)-np.min(layer)+1e-6)
32 |
33 | if visual_bpr is not None:
34 | if visual_bpr.shape != layer.shape:
35 | visual_bpr = self._deconv(visual_bpr)
36 | visual_bpr = visual_bpr * layer
37 | else:
38 | visual_bpr = layer
39 |
40 | return visual_bpr[0]
41 |
42 | def _deconv(self, feature_map):
43 | """The deconvolution operation to upsample the average feature map downstream"""
44 | x = Input(shape=(None, None, 1))
45 | y = Conv2DTranspose(filters=1,
46 | kernel_size=(3,3),
47 | strides=(2,2),
48 | padding='same',
49 | kernel_initializer=Ones(),
50 | bias_initializer=Zeros())(x)
51 |
52 | deconv_model = Model(inputs=[x], outputs=[y])
53 |
54 | inps = [deconv_model.input, K.learning_phase()] # input placeholder
55 | outs = [deconv_model.layers[-1].output] # output placeholder
56 | deconv_func = K.function(inps, outs) # evaluation function
57 |
58 | return deconv_func([feature_map, 0])[0]
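59 |
60 | # A minimal usage sketch (illustrative, not part of the module; assumes a trained
61 | # Keras CNN `model` and a single input `image` of shape (H, W, C)):
62 | #
63 | #   vbp = VisualBackprop(model)
64 | #   mask = vbp.get_mask(image)
65 | #
66 | # get_mask averages each Conv2D layer's feature maps, normalizes them to [0, 1],
67 | # and multiplies them together from the deepest layer back towards the input,
68 | # upsampling with the transposed convolution in _deconv whenever the resolution changes.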
--------------------------------------------------------------------------------
/deepviz/saliency.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities to compute SaliencyMasks."""
16 | import numpy as np
17 | import keras.backend as K
18 |
19 | class SaliencyMask(object):
20 | """Base class for saliency masks. Alone, this class doesn't do anything."""
21 | def __init__(self, model, output_index=0):
22 | """Constructs a SaliencyMask.
23 |
24 | Args:
25 | model: the keras model used to make prediction
26 | output_index: the index of the node in the last layer to take derivative on
27 | """
28 | pass
29 |
30 | def get_mask(self, input_image):
31 | """Returns an unsmoothed mask.
32 |
33 | Args:
34 | input_image: input image with shape (H, W, 3).
35 | """
36 | pass
37 |
38 | def get_smoothed_mask(self, input_image, stdev_spread=.2, nsamples=50):
39 | """Returns a mask that is smoothed with the SmoothGrad method.
40 |
41 | Args:
42 | input_image: input image with shape (H, W, 3).
43 | """
44 | stdev = stdev_spread * (np.max(input_image) - np.min(input_image))
45 |
46 | total_gradients = np.zeros_like(input_image)
47 | for i in range(nsamples):
48 | noise = np.random.normal(0, stdev, input_image.shape)
49 | x_value_plus_noise = input_image + noise
50 |
51 | total_gradients += self.get_mask(x_value_plus_noise)
52 |
53 | return total_gradients / nsamples
54 |
55 | class GradientSaliency(SaliencyMask):
56 | r"""A SaliencyMask class that computes saliency masks with a gradient."""
57 |
58 | def __init__(self, model, output_index=0):
59 | # Define the function to compute the gradient
60 | input_tensors = [model.input, # placeholder for input image tensor
61 |                          K.learning_phase(), # placeholder for mode (train or test) tensor
62 | ]
63 | gradients = model.optimizer.get_gradients(model.output[0][output_index], model.input)
64 | self.compute_gradients = K.function(inputs=input_tensors, outputs=gradients)
65 |
66 | def get_mask(self, input_image):
67 | """Returns a vanilla gradient mask.
68 |
69 | Args:
70 | input_image: input image with shape (H, W, 3).
71 | """
72 |
73 | # Execute the function to compute the gradient
74 | x_value = np.expand_dims(input_image, axis=0)
75 | gradients = self.compute_gradients([x_value, 0])[0][0]
76 |
77 | return gradients
78 |
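79 | # Note: get_smoothed_mask implements SmoothGrad (https://arxiv.org/abs/1706.03825):
80 | #   M_smooth(x) = (1/nsamples) * sum_i M(x + N(0, stdev^2)),
81 | # where stdev = stdev_spread * (max(x) - min(x)) and M is the subclass's get_mask.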
--------------------------------------------------------------------------------
/deepviz/guided_backprop.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilites to computed GuidedBackprop SaliencyMasks"""
16 |
17 | from saliency import SaliencyMask
18 | import numpy as np
19 | import tensorflow as tf
20 | import keras.backend as K
21 | from keras.models import load_model
22 |
23 | class GuidedBackprop(SaliencyMask):
24 | """A SaliencyMask class that computes saliency masks with GuidedBackProp.
25 |
26 | This implementation copies the TensorFlow graph to a new graph with the ReLU
27 | gradient overwritten as in the paper:
28 | https://arxiv.org/abs/1412.6806
29 | """
30 |
31 | GuidedReluRegistered = False
32 |
33 | def __init__(self, model, output_index=0, custom_loss=None):
34 | """Constructs a GuidedBackprop SaliencyMask."""
35 |
36 | if GuidedBackprop.GuidedReluRegistered is False:
37 | @tf.RegisterGradient("GuidedRelu")
38 | def _GuidedReluGrad(op, grad):
39 | gate_g = tf.cast(grad > 0, "float32")
40 | gate_y = tf.cast(op.outputs[0] > 0, "float32")
41 | return gate_y * gate_g * grad
42 | GuidedBackprop.GuidedReluRegistered = True
43 |
44 | """
45 | Create a dummy session to set the learning phase to 0 (test mode in keras) without
46 | inteferring with the session in the original keras model. This is a workaround
47 | for the problem that tf.gradients returns error with keras models that contains
48 | Dropout or BatchNormalization.
49 |
50 | Basic Idea: save keras model => create new keras model with learning phase set to 0 => save
51 | the tensorflow graph => create new tensorflow graph with ReLU replaced by GuiededReLU.
52 | """
53 | model.save('/tmp/gb_keras.h5')
54 | with tf.Graph().as_default():
55 | with tf.Session().as_default():
56 | K.set_learning_phase(0)
57 | load_model('/tmp/gb_keras.h5', custom_objects={"custom_loss":custom_loss})
58 | session = K.get_session()
59 | tf.train.export_meta_graph()
60 |
61 | saver = tf.train.Saver()
62 | saver.save(session, '/tmp/guided_backprop_ckpt')
63 |
64 | self.guided_graph = tf.Graph()
65 | with self.guided_graph.as_default():
66 | self.guided_sess = tf.Session(graph = self.guided_graph)
67 |
68 | with self.guided_graph.gradient_override_map({'Relu': 'GuidedRelu'}):
69 | saver = tf.train.import_meta_graph('/tmp/guided_backprop_ckpt.meta')
70 | saver.restore(self.guided_sess, '/tmp/guided_backprop_ckpt')
71 |
72 | self.imported_y = self.guided_graph.get_tensor_by_name(model.output.name)[0][output_index]
73 | self.imported_x = self.guided_graph.get_tensor_by_name(model.input.name)
74 |
75 | self.guided_grads_node = tf.gradients(self.imported_y, self.imported_x)
76 |
77 | def get_mask(self, input_image):
78 | """Returns a GuidedBackprop mask."""
79 | x_value = np.expand_dims(input_image, axis=0)
80 | guided_feed_dict = {}
81 | guided_feed_dict[self.imported_x] = x_value
82 |
83 | gradients = self.guided_sess.run(self.guided_grads_node, feed_dict = guided_feed_dict)[0][0]
84 |
85 | return gradients
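86 |
87 | # Worked example of the "GuidedRelu" gradient above (illustrative): for forward
88 | # activations y = [0.3, 0.0] and incoming gradients grad = [-1.0, 2.0],
89 | # gate_y = [1, 0] and gate_g = [0, 1], so the guided gradient is [0.0, 0.0];
90 | # only units with a positive activation AND a positive incoming gradient pass signal.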
--------------------------------------------------------------------------------
/subjective_training.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from src import load_data, preprocessing, cross_validator, save_resutls
4 |
5 | import matplotlib.pyplot as plt
6 | from os import listdir
7 | from os.path import isfile, join
8 | from utils import plot_confusion_matrix
9 | import sys
10 | import pdb
11 | #%%
12 |
13 | eeglabel = ['Fz','Cz','Pz','Oz','Iz','Fp1','Fp2','F3','F4','F7','F8','T7','T8','C3','C4','P3','P4','O9','O10','P7',
14 | 'P8','FC1','FC2','CP1','CP2','PO3','PO4','PO7','PO8','LMAST']
15 |
16 | subjects = [sys.argv[1]]
17 | #subjects = ['kq84']
18 |
19 | batch_size = 64
20 | lr = 0.001
21 | early_stopping=True
22 | patience=20
23 | epochs=500
24 | model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}
25 |
26 |
27 |
28 | datasets = ['50_avg','250']
29 |
30 |
31 | for dataset in datasets:
32 | model_names = ['deep_subjective_branched_'+str(dataset)+'_thesis1',
33 | 'deep_subjective_eegnet_'+str(dataset)+'_thesis1',
34 | 'deep_subjective_cnn_'+str(dataset)+'_thesis1',
35 | 'lda_subjective_shrinkage_'+str(dataset)+'_thesis1',
36 | 'lda_subjective_'+str(dataset)+'_thesis1',
37 | 'deep_subjective_branched_no_bn_'+str(dataset)+'_thesis1',
38 | 'deep_subjective_branched_no_dropout_'+str(dataset)+'_thesis1',
39 | 'deep_subjective_branched_no_branched_'+str(dataset)+'_thesis1',
40 | 'deep_subjective_branched_no_deep_'+str(dataset)+'_thesis1',
41 | 'deep_subjective_branched_relu_'+str(dataset)+'_thesis1',
42 | 'deep_subjective_branched_elu_'+str(dataset)+'_thesis1']
43 |
44 | model_configs = [{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
45 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
46 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
47 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
48 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
49 | {'bn':False, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
50 | {'bn':True, 'dropout':False, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
51 | {'bn':True, 'dropout':True, 'branched':False, 'deep':True, 'nonlinear':'tanh'},
52 | {'bn':True, 'dropout':True, 'branched':True, 'deep':False, 'nonlinear':'tanh'},
53 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'relu'},
54 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'elu'}]
55 |
56 |
57 | for model_name, model_config in zip(model_names, model_configs):
58 | for subject in subjects:
59 |
60 |             print 'working on subject:', subject
61 | x, y, t, tr = load_data(subject, channels=range(29), frequency=dataset)
62 | x = preprocessing(x, frequency=dataset)
63 | metrics, histories, cnf_matrices = cross_validator((x, y, t, tr),subject,
64 | n_splits=5, epochs=epochs,
65 | batch_size=batch_size,
66 | lr=lr,
67 | model_name=model_name,
68 | model_config=model_config,
69 | early_stopping=early_stopping,
70 | patience=patience)
71 |
72 | super_final_results = save_resutls(metrics, histories, subject, suffix=model_name, early_stopping=early_stopping,
73 | patience=patience)
74 |
75 | print 'DA:', model_name,'_', subject, ' is:', super_final_results[0]['val_recognition_acc']['mean']
76 | print 'BA:', model_name,'_', subject, ' is:', super_final_results[0]['val_balanced_acc']['mean']
77 | print 'Recall:', model_name,'_', subject, ' is:', super_final_results[0]['val_recalls']['mean']
78 | print 'precision:', model_name,'_', subject, ' is:', super_final_results[0]['val_precisions']['mean']
79 |
--------------------------------------------------------------------------------
/visualizations.py:
--------------------------------------------------------------------------------
1 | from src import collect_data_intersubjective, transform
2 | import matplotlib.pyplot as plt
3 | plt.switch_backend('agg')
4 | from vis.visualization import visualize_activation
5 | from vis.utils import utils
6 | from keras import activations
7 | from keras.models import load_model
8 | import numpy as np
9 | from os import listdir
10 | from os.path import isfile, join
11 | from vis.visualization import get_num_filters
12 | from deepviz.saliency import GradientSaliency
13 | from deepviz.guided_backprop import GuidedBackprop
14 |
15 | ## this script visualizes the saliency maps and activation maximization maps for the trained models.
16 |
17 | # define a function to show and/or save the maps.
18 | def show_save_image(matrix, prefix='', actions=['save']):
19 | fig, ax = plt.subplots(1,1)
20 | fig.set_size_inches(20,12)
21 | if prefix.startswith('saliency'):
22 | cax = ax.imshow(np.swapaxes(matrix, 0, 1)[:,:,0], cmap='jet', vmin=-0.015, vmax=0.015)
23 | elif prefix.startswith('activation'):
24 | cax = ax.imshow(np.swapaxes(matrix, 0, 1)[:,:,0], cmap='jet', vmin=-4, vmax=4)
25 |     ax.set_yticks(range(0,len(eeglabel))) # set tick positions before labels
26 |     ax.set_yticklabels(eeglabel)
27 | ax.set_title(prefix)
28 | ax.set_xticklabels(range(-200,801, 100))
29 | cbar = fig.colorbar(cax)
30 | ax.set_aspect('auto')
31 | ax.set_xlabel('time')
32 | ax.set_ylabel('channel')
33 | if 'save' in actions:
34 | fig.savefig('plots/'+prefix+'.png')
35 | if 'show' in actions:
36 | plt.show()
37 | plt.close()
38 |
39 |
40 |
41 |
42 | subjects = [name for name in listdir("./data/50/subjects/")]
43 | path = './models/subjects/'
44 | # defining channels names
45 | eeglabel = ['Fz','Cz','Pz','Oz','Iz','Fp1','Fp2','F3','F4','F7','F8','T7','T8','C3','C4','P3','P4','O9','O10','P7',
46 | 'P8','FC1','FC2','CP1','CP2','PO3','PO4','PO7','PO8','LMAST']
47 | # defining the experiments that need to be visualized to get their models from the models directory.
48 | trial_names = ['intersubjective_eegnet', 'intersubjective_ft_10eegnet', 'intersubjective_ft_20eegnet', 'intersubjective_ft_30eegnet',
49 | 'intersubjective_ft_50eegnet', 'intersubjective_ft_70eegnet']
50 |
51 | all_grads = np.empty((len(subjects),41,30,1))
52 |
53 | for n, test_subject in enumerate(subjects):
54 | print 'WORKING ON', test_subject
55 |
56 | # loading the subject data
57 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = collect_data_intersubjective(subjects, test_subject)
58 | x_train, y_train, x_test, y_test = transform((x_train, y_train, x_test, y_test))
59 |
60 |     # compute the minimum and maximum input values to bound the activation maximization search
61 | min_values = np.amin(np.min(x_train, axis=1), axis=1)
62 | max_values = np.amax(np.amax(x_train, axis=1), axis=1)
63 |
64 | min_value = np.mean(min_values)
65 | max_value = np.mean(max_values)
66 |
67 |     # create the maps for each model/experiment
68 | for trial_name in trial_names:
69 |
70 | files = [f for f in listdir(join(path, test_subject)) if isfile(join(path, test_subject, f))]
71 | myfile = [file for file in files if trial_name in file ][0]
72 |
73 | model = load_model(join(path, test_subject, myfile))
74 |         # create the saliency maps only for the models that are not fine-tuned (a deliberate choice).
75 | if '_ft_' not in trial_name:
76 | attended_examples = x_test[y_test==1]
77 | subj_grads = np.empty(attended_examples.shape+(1,))
78 | vanilla = GradientSaliency(model)
79 |             # create a saliency map for each attended example, then average them.
80 | for ex in range(attended_examples.shape[0]):
81 | example = attended_examples[ex]
82 | example = np.expand_dims(example, axis=2)
83 | grads = vanilla.get_smoothed_mask(example)
84 | subj_grads[ex] = grads
85 | subj_grads = np.mean(subj_grads, axis=0)
86 | all_grads[n] = subj_grads
87 | show_save_image(subj_grads, prefix='saliency_maps/'+test_subject+'_'+trial_name, actions=['save'])
88 |
89 |
90 | # computing the activation maximization maps
91 | layer_idx = -1
92 | filters = np.arange(get_num_filters(model.layers[layer_idx]))
93 | for filter_idx in filters:
94 |             # swap the final sigmoid for a linear activation so the optimization targets the raw logit.
95 | model.layers[layer_idx].activation = activations.linear
96 | model = utils.apply_modifications(model)
97 | img = visualize_activation(model, layer_idx, filter_indices=filter_idx, input_range=(min_value,max_value))
98 | show_save_image(img, prefix='activation_maximization/'+test_subject+'_'+trial_name, actions=['save'])
99 |
100 |
101 |
102 | # save the numerical values of the saliency maps of all subjects; they will be used
103 | # later to determine the most significant channels.
104 |
105 | np.save('plots/saliency_maps/all_grads_eegnet.npy', all_grads)
106 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Mar 26 13:57:10 2018
5 |
6 | @author: amr
7 | """
8 | from keras.models import Sequential, Model
9 | from keras.layers import Dense, LSTM, GRU, Input
10 | from keras.layers import Conv1D, Conv2D, MaxPooling2D, AveragePooling2D, AveragePooling1D, GlobalAveragePooling2D
11 | from keras.layers import Activation, Dropout, Flatten, Reshape, SeparableConv2D
12 | from keras.layers import PReLU
13 | from keras.layers.merge import average, concatenate, add
14 | from keras.layers.normalization import BatchNormalization
15 | from keras.layers.wrappers import TimeDistributed, Bidirectional
16 | from keras.layers.local import LocallyConnected1D
17 | from keras.optimizers import Adam, RMSprop, SGD
18 | from keras.callbacks import Callback, TensorBoard, EarlyStopping, LearningRateScheduler, ModelCheckpoint, ReduceLROnPlateau
19 | from keras.constraints import maxnorm
20 | from keras.utils import plot_model
21 | from keras import regularizers
22 | from keras.regularizers import l2, l1
23 | import keras.backend as K
24 |
25 | def branched2(data_shape, model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}, f=1):
26 | timepoints = data_shape[1]
27 | channels = data_shape[2]
28 | reg = 0.01
29 | input_data = Input(shape=(timepoints, channels, 1))
30 |
31 |
32 | spatial_conv = Conv2D(12, (1,channels), padding='valid', kernel_regularizer=l2(reg))(input_data)
33 | if model_config['bn']:
34 | spatial_conv = BatchNormalization()(spatial_conv)
35 | spatial_conv = Activation(model_config['nonlinear'])(spatial_conv)
36 | if model_config['dropout']:
37 | spatial_conv = Dropout(0.5)(spatial_conv)
38 |
39 | if model_config['branched']:
40 | branch1 = Conv2D(4, (21*f,1), padding='valid', kernel_regularizer=l2(reg))(spatial_conv)
41 | if model_config['bn']:
42 | branch1 = BatchNormalization()(branch1)
43 | branch1 = Activation(model_config['nonlinear'])(branch1)
44 | branch1 = MaxPooling2D(pool_size=(3,1), strides=(3,1))(branch1)
45 | if model_config['dropout']:
46 | branch1 = Dropout(0.5)(branch1)
47 |
48 |
49 | branch1 = Flatten()(branch1)
50 |
51 | branch2 = Conv2D(4, (5*f,1), padding='valid', dilation_rate=(1,1), kernel_regularizer=l2(reg))(spatial_conv)
52 | if model_config['bn']:
53 | branch2 = BatchNormalization()(branch2)
54 | branch2 = Activation(model_config['nonlinear'])(branch2)
55 | branch2 = MaxPooling2D(pool_size=(3,1), strides=(3,1))(branch2)
56 | if model_config['dropout']:
57 | branch2 = Dropout(0.5)(branch2)
58 |
59 | if model_config['deep']:
60 | branch2 = Conv2D(8, (5*f,1), padding='valid', dilation_rate=(1,1), kernel_regularizer=l2(reg))(branch2)
61 | if model_config['bn']:
62 | branch2 = BatchNormalization()(branch2)
63 | branch2 = Activation(model_config['nonlinear'])(branch2)
64 | if model_config['dropout']:
65 | branch2 = Dropout(0.5)(branch2)
66 | #
67 | branch2 = Conv2D(8, (5*f,1), padding='valid', kernel_regularizer=l2(reg))(branch2)
68 | if model_config['bn']:
69 | branch2 = BatchNormalization()(branch2)
70 | branch2 = Activation(model_config['nonlinear'])(branch2)
71 | branch2 = MaxPooling2D(pool_size=(2,1), strides=(2,1))(branch2)
72 | if model_config['dropout']:
73 | branch2 = Dropout(0.5)(branch2)
74 | #
75 | branch2 = Flatten()(branch2)
76 |
77 | if model_config['branched']:
78 | merged = concatenate([branch1, branch2])
79 | dense = Dense(1, activation='sigmoid', kernel_regularizer=l2(reg))(merged)
80 | else:
81 | dense = Dense(1, activation='sigmoid', kernel_regularizer=l2(reg))(branch2)
82 |
83 | model = Model(inputs = [input_data], outputs=[dense])
84 |
85 |
86 | return model
87 |
88 |
89 |
90 | def create_cnn(data_shape, f=1):
91 |
92 | timepoints = data_shape[1]
93 | channels = data_shape[2]
94 | kernel_size = 3*f
95 | model = Sequential()
96 |
97 | model.add(Conv2D(4, (kernel_size,3), activation='tanh', input_shape=(timepoints, channels, 1)))
98 | model.add(BatchNormalization())
99 | model.add(MaxPooling2D((2,2)))
100 | model.add(Dropout(0.25))
101 |
102 |
103 | model.add(Conv2D(8, (kernel_size,3), strides=(1,1), activation='tanh', padding='valid'))
104 | model.add(BatchNormalization())
105 | model.add(MaxPooling2D((2,2)))
106 | # model.add(Dropout(0.25))
107 |
108 | model.add(Conv2D(16, (kernel_size,3), strides=(1,1), activation='tanh', padding='valid'))
109 | model.add(BatchNormalization())
110 | model.add(MaxPooling2D((2,2)))
111 | # model.add(Dropout(0.25))
112 |
113 | model.add(Flatten())
114 |
115 | model.add(Dense(1, activation='sigmoid'))
116 |
117 | return model
118 |
119 | def create_eegnet(data_shape, f=1):
120 |
121 | timepoints = data_shape[1]
122 | channels = data_shape[2]
123 |
124 | spatial_filters = 16
125 | model = Sequential()
126 |
127 | model.add(Conv2D(spatial_filters, (1,channels), activation='relu',
128 | kernel_regularizer=regularizers.l1_l2(0.0001), input_shape=(timepoints, channels, 1)))
129 | model.add(BatchNormalization())
130 | model.add(Dropout(0.25))
131 | model.add(Reshape((timepoints,spatial_filters,1)))
132 |
133 | model.add(Conv2D(4, (16*f,2), strides=(1,1), activation='relu', padding='same'))
134 | model.add(BatchNormalization())
135 | model.add(MaxPooling2D((4,2)))
136 | model.add(Dropout(0.25))
137 |
138 | model.add(Conv2D(4, (2*f,8), strides=(1,1), activation='relu', padding='same'))
139 | model.add(BatchNormalization())
140 | model.add(MaxPooling2D((4,2)))
141 | model.add(Dropout(0.25))
142 |
143 | model.add(Flatten())
144 | model.add(Dense(1, activation='sigmoid'))
145 | return model
146 |
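147 | # Illustrative usage (a sketch; the shapes assume the 50 Hz pipeline in this repo,
148 | # i.e. segments of 41 timepoints x 30 channels, so x_train.shape == (N, 41, 30, 1)):
149 | #
150 | #   model = branched2(x_train.shape)           # wide-and-deep branched CNN (f=5 for the 250 Hz data)
151 | #   model = create_eegnet(x_train.shape, f=1)  # EEGNet-style baseline
152 | #   model = create_cnn(x_train.shape, f=1)     # plain CNN baseline
153 | #   model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])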
--------------------------------------------------------------------------------
/intersubjective_training.py:
--------------------------------------------------------------------------------
1 | from src import collect_data_intersubjective, resample_transform, intersubjective_training, save_resutls, intersubjective_shallow, finetune
2 | import os
3 | from os.path import isfile, join
4 | from keras.models import load_model, clone_model
5 |
6 | from keras.optimizers import SGD
7 | import sys
8 | #import csv
9 | import pdb
10 | #%%
11 |
12 |
13 | subjects = [name for name in os.listdir("./data/50/subjects/")]
14 | #test_subjects = ['ab82']
15 | test_subjects = [sys.argv[1]]
16 |
17 | batch_size = 64
18 | lr = 0.001
19 | early_stopping = True
20 | epochs = 500
21 | patience = 20
22 |
23 | model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}
24 |
25 | ft = False
26 | ft_mode = 'all'
27 | ft_trials = [10, 20, 30, 40, 50, 60, 70]
28 | #
29 |
30 |
31 | datasets = ['50_avg', '250']
32 |
33 |
34 | for dataset in datasets:
35 | model_names = ['deep_intersubjective_branched_'+str(dataset)+'_thesis1',
36 | 'deep_intersubjective_eegnet_'+str(dataset)+'_thesis1',
37 | 'deep_intersubjective_cnn_'+str(dataset)+'_thesis1',
38 | 'lda_intersubjective_shrinkage_'+str(dataset)+'_thesis1',
39 | 'lda_intersubjective_'+str(dataset)+'_thesis1',
40 | 'deep_intersubjective_branched_no_bn_'+str(dataset)+'_thesis1',
41 | 'deep_intersubjective_branched_no_dropout_'+str(dataset)+'_thesis1',
42 | 'deep_intersubjective_branched_no_branched_'+str(dataset)+'_thesis1',
43 | 'deep_intersubjective_branched_no_deep_'+str(dataset)+'_thesis1',
44 | 'deep_intersubjective_branched_relu_'+str(dataset)+'_thesis1',
45 | 'deep_intersubjective_branched_elu_'+str(dataset)+'_thesis1']
46 |
47 | model_configs = [{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
48 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
49 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
50 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
51 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
52 | {'bn':False, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
53 | {'bn':True, 'dropout':False, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
54 | {'bn':True, 'dropout':True, 'branched':False, 'deep':True, 'nonlinear':'tanh'},
55 | {'bn':True, 'dropout':True, 'branched':True, 'deep':False, 'nonlinear':'tanh'},
56 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'relu'},
57 | {'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'elu'}]
58 |
59 |
60 | for model_name, model_config in zip(model_names, model_configs):
61 |
62 |
63 | for test_subject in test_subjects:
64 | print 'working on subject', test_subject
65 | #test_subject = subject
66 | #collecting the data
67 | print 'Collecting data ...'
68 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = collect_data_intersubjective(subjects,
69 | test_subject,
70 | mode='eeg',
71 | channels=range(29),
72 | frequency=dataset)
73 |
74 |
75 | #Train
76 | print "Training ..."
77 | if model_name.startswith('deep'):
78 | metrics, history, cnf_matrix = intersubjective_training((x_train, y_train, x_test, y_test, o_t_test, o_tr_test),
79 | model_name,
80 | test_subject,
81 | epochs=epochs,
82 | lr=lr,
83 | batch_size=batch_size,
84 | model_config=model_config,
85 | early_stopping=early_stopping,
86 | patience=patience)
87 | super_final_results = save_resutls([metrics], [history], test_subject,
88 | suffix=model_name,
89 | early_stopping=early_stopping,
90 | patience=patience)
91 | else:
92 | metrics, history, cnf_matrix, clf = intersubjective_shallow((x_train, y_train, x_test, y_test, o_t_test, o_tr_test),
93 | model_name)
94 | super_final_results = save_resutls(metrics, history, test_subject, suffix=model_name, clf=clf)
95 |
96 | print 'DA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean']
97 | print 'BA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean']
98 | print 'Recall:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean']
99 | print 'precision:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean']
100 |
101 | if model_name.startswith('deep_intersubjective_branched') and ft:
102 | model=clone_model(history.model)
103 | weights = history.model.get_weights()
104 |
105 | for i in ft_trials:
106 | model_name_modified = 'deep_intersubjective_branched_ft_'+str(i)+'_trials_'+str(dataset)+'_thesis2'
107 | model_name_modified = model_name+'_ft_'+str(i)+'_trials'
108 | model.set_weights(weights)
109 | metrics, history, cnf_matrix = finetune(model,
110 | (x_test, y_test, o_t_test, o_tr_test),
111 | model_name_modified,
112 | test_subject,
113 | epochs=epochs,
114 | train_trials=i,
115 | mode=ft_mode,
116 | early_stopping=early_stopping,
117 | patience=patience)
118 |
119 | super_final_results = save_resutls([metrics], [history], test_subject,
120 | suffix=model_name_modified,
121 | early_stopping=early_stopping,
122 | patience=patience)
123 | print 'DA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean']
124 | print 'BA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean']
125 | print 'Recall:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean']
126 | print 'precision:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean']
127 |
--------------------------------------------------------------------------------
/src.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import tensorflow as tf
4 | import random as rn
5 | np.random.seed(42)
6 | rn.seed(12345)
7 | tf.set_random_seed(1234)
8 | from models import *
9 | import matplotlib.pyplot as plt
10 | from imblearn.under_sampling import RandomUnderSampler
11 | from imblearn.over_sampling import RandomOverSampler
12 | from sklearn.model_selection import train_test_split, LeavePGroupsOut, GroupKFold, StratifiedKFold
13 | from sklearn.preprocessing import MinMaxScaler, StandardScaler
14 | from sklearn.utils.class_weight import compute_class_weight
15 | from sklearn.model_selection import StratifiedKFold, train_test_split, StratifiedShuffleSplit
16 | from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
17 | import itertools
18 | from utils import plot_confusion_matrix
19 | import pdb
20 | import os
21 | from os import listdir
22 | from os.path import isfile, join
23 | import time
24 | import math
25 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
26 | from sklearn import svm
27 | from sklearn.externals import joblib
28 | from keras.models import load_model
29 |
30 |
31 | def load_data(subject, mode='eeg', channels=[], frequency=50):
32 | data_path = 'data/'+str(frequency)+'/subjects/'
33 | if mode=='eeg':
34 |         x = np.load(data_path+subject+'/segments.npy')
35 |     elif mode=='meg':
36 |         x = np.load(data_path+subject+'/meg_segments.npy')
37 |     else:
38 |         print 'Wrong mode. You can only choose eeg or meg.'
39 |     y = np.load(data_path+subject+'/labels.npy')
40 |     t = np.load(data_path+subject+'/target.npy')
41 |     tr = np.load(data_path+subject+'/triggers.npy')
42 | temp_t = []
43 | for i in t:
44 |         trial_index = np.ones((60,))*i
45 |         temp_t.extend(trial_index)
46 | t = np.array(temp_t).astype(int)
47 |
48 | if len(channels):
49 | x = x[:,channels,:]
50 |
51 | # indexes = np.random.permutation(len(tr))
52 | # tr = tr[indexes]
53 | return x, y, t, tr
54 | def preprocessing(x, frequency='50_avg'):
55 | x = np.swapaxes(x,1,2)
56 | ## baseline correction
57 | if frequency == '50_avg':
58 | bl = 5
59 | elif frequency=='250':
60 | bl = 25
61 | # bl = int(frequency/10)
62 | corrected = np.empty_like(x)
63 | for i in range(x.shape[0]):
64 | baselines = np.mean(x[i,0:bl,:], axis=0)
65 | corrected[i] = x[i] - baselines
66 | x = corrected
67 | return x
68 |
69 | def precision(y_true, y_pred):
70 | return precision_score(y_true, y_pred)
71 | def recall(y_true, y_pred):
72 | return recall_score(y_true, y_pred)
73 | def f1(y_true, y_pred):
74 | return f1_score(y_true, y_pred)
75 | def auc(y_true, y_pred):
76 | return roc_auc_score(y_true, y_pred)
77 | def balanced_accuracy(y_true, y_pred):
78 | tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
79 | recall_p = float(tp) / (tp + fn)
80 | recall_n = float(tn) / (tn + fp)
81 | return (recall_p + recall_n) /2.0
82 |
83 | #def recognition_accuracy(probs):
84 | # trials = int(len(probs)/60)
85 | # n_correct = 0
86 | # matrix = []
87 | # for i in range(trials):
88 | # classes = np.zeros((12,))
89 | # for k in range(i*60,i*60+60):
90 | #
91 | # classes[tr_test[k]-1] += probs[k]
92 | # choosen = np.argmax(classes)+1
93 | # if choosen == t_test[i*60]:
94 | # n_correct+=1
95 | # confidence = classes/float(np.sum(classes))
96 | # row = np.array([confidence, t_test[i*60]])
97 | # matrix.append(row)
98 | # return np.array([n_correct/float(trials), np.array(matrix)])
99 | def recognition_accuracy(probs):
100 | # pdb.set_trace()
101 | trials = int(len(probs)/60)
102 | n_correct = np.zeros((5,))
103 | matrix = []
104 | for i in range(trials):
105 | classes = [[] for s in range(12)]
106 | for k in range(i*60,i*60+60):
107 |
108 | classes[tr_test[k]-1].append(probs[k])
109 | classes = np.array(classes)
110 | # pdb.set_trace()
111 | if classes.ndim > 1:
112 | classes = np.cumsum(classes, axis=1)
113 | choosen = np.argmax(classes, axis=0)+1
114 | choices = []
115 | for j in range(5):
116 | if choosen[j] == t_test[i*60]:
117 | n_correct[j]+=1
118 | choices.append((t_test[i*60], choosen[j]))
119 | # confidence = classes/float(np.sum(classes))
120 | # row = np.array([confidence, t_test[i*60]])
121 | # matrix.append(row)
122 | matrix.append(choices)
123 | return np.array([n_correct/float(trials), np.array(matrix)])
124 |
125 | def bitperminute(p, n, t):
126 |     p[p==1] = 1-np.finfo(float).eps # avoid log2(0) at perfect recognition accuracy
127 |     B = np.log2(n)+p*np.log2(p)+(1-p)*np.log2((1-p)/(n-1).astype(float)) # Wolpaw bit rate: log2(N) + P*log2(P) + (1-P)*log2((1-P)/(N-1))
128 |     return B*(float(60)/t) # bits per selection -> bits per minute, t = seconds per selection
129 |
130 | class Metrics(Callback):
131 | def on_train_begin(self, logs={}):
132 | self.val_precisions = []
133 | self.val_recalls = []
134 | self.val_f1s = []
135 | self.val_aucs = []
136 | self.val_balanced_acc = []
137 | self.val_recognition_acc = []
138 | self.val_bpm = []
139 | self.test_acc = []
140 | self.test_loss = []
141 | def on_epoch_end(self, epoch, logs={}):
142 | # probs = self.model.predict(self.validation_data[0])
143 | # y_pred = np.round(probs)
144 | # y_true = self.validation_data[1]
145 | probs = self.model.predict(x_test)
146 | probs = probs.ravel()
147 | y_pred = np.round(probs)
148 | y_true = y_test
149 | self.test_acc.append(accuracy_score(y_true, y_pred))
150 | self.test_loss.append(self.model.evaluate(x_test, y_test, batch_size = 64, verbose=0)[0])
151 | self.val_precisions.append(precision(y_true, y_pred))
152 | self.val_recalls.append(recall(y_true, y_pred))
153 | self.val_f1s.append(f1(y_true, y_pred))
154 | self.val_aucs.append(auc(y_true, probs))
155 | self.val_balanced_acc.append(balanced_accuracy(y_true, y_pred))
156 | self.val_recognition_acc.append(recognition_accuracy(probs))
157 | self.val_bpm.append(bitperminute(self.val_recognition_acc[-1][0], np.ones((5,))*12,np.arange(2,12,2)))
158 | return
159 |
160 |
161 | def step_decay(epoch):
162 | initial_lrate = 0.001
163 | drop = 0.5
164 | epochs_drop = 10.0
165 | lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
166 | # if epoch:
167 | # lrate = initial_lrate/np.sqrt(epoch)
168 | # else:
169 | # return initial_lrate
170 | return lrate
171 |
172 | def cv_splitter(x, n_splits=5):
173 |
174 | n_segments = x.shape[0]
175 | n_trials = int(n_segments/60)
176 | groups = []
177 | for i in range(n_trials):
178 | trial_index = np.ones((60,))*i
179 | groups.extend(trial_index)
180 | groups = np.array(groups)
181 |
182 | window = int(np.round(n_trials/float(n_splits)))
183 |
184 | intervals=[]
185 | for i in range(0,n_trials,window):
186 | intervals.append(range(i,np.minimum(i+window, n_trials)))
187 |
188 |     if len(intervals[-1]) < 4:
[... lines 189-543 of src.py are missing from this dump ...]
544 | x_train, y_train, x_test, y_test, x_valid, y_valid = data
545 | else:
546 | x_train, y_train, x_test, y_test = data
547 |
548 |
549 | # standarization of the data
550 | # computing the mean and std on the training data
551 | # scalar = StandardScaler(with_mean=False)
552 | # stds = []
553 | # trials_no = x_train.shape[0]
554 | # for i in range(trials_no):
555 | # scalar.fit(x_train[i])
556 | # std = scalar.scale_
557 | # stds.append(std)
558 | #
559 | # scalar.scale_ = np.mean(stds, axis=0)
560 | # normalized_x_train = np.empty_like(x_train)
561 | # for i in range(trials_no):
562 | # temp = scalar.transform(x_train[i])
563 | # normalized_x_train[i] = temp
564 | #
565 | # # transforming the test data
566 | # normalized_x_test = np.empty_like(x_test)
567 | # trials_no = x_test.shape[0]
568 | # for i in range(trials_no):
569 | # temp = scalar.transform(x_test[i])
570 | # normalized_x_test[i] = temp
571 | #
572 | # x_train = normalized_x_train
573 | # x_test = normalized_x_test
574 |
575 |
576 | scalar = StandardScaler(with_mean=True)
577 | scalar.fit(x_train.reshape(x_train.shape[0],-1))
578 | normalized_x_train = scalar.transform(x_train.reshape(x_train.shape[0], -1)).reshape(x_train.shape)
579 | normalized_x_test = scalar.transform(x_test.reshape(x_test.shape[0], -1)).reshape(x_test.shape)
580 | if len(data)>4:
581 | normalized_x_valid = scalar.transform(x_valid.reshape(x_valid.shape[0], -1)).reshape(x_valid.shape)
582 | x_train = normalized_x_train
583 | x_test = normalized_x_test
584 | if len(data) > 4:
585 | x_valid = normalized_x_valid
586 |
587 | if resample:
588 | #resampling the data
589 | n_samples, timepoints, channels = x_train.shape
590 | x_train = np.reshape(x_train, (n_samples, timepoints * channels))
591 | ros = RandomOverSampler(random_state=0)
592 | x_res, y_res = ros.fit_sample(x_train, y_train)
593 | x_train = np.reshape(x_res, (x_res.shape[0], timepoints, channels))
594 | y_train = y_res
595 |
596 | x_train = np.expand_dims(x_train, axis=3)
597 | x_test = np.expand_dims(x_test, axis=3)
598 | if len(data) > 4:
599 | x_valid = np.expand_dims(x_valid, axis=3)
600 | return x_train, y_train, x_test, y_test, x_valid, y_valid
601 | return x_train, y_train, x_test, y_test
602 |
603 | def transform(data):
604 |
605 | x_train, y_train, x_test, y_test = data
606 |
607 |
608 | # standarization of the data
609 | # computing the mean and std on the training data
610 | scalar = StandardScaler(with_mean=False)
611 | stds = []
612 | trials_no = x_train.shape[0]
613 | for i in range(trials_no):
614 | scalar.fit(x_train[i])
615 | std = scalar.scale_
616 | stds.append(std)
617 |
618 | scalar.scale_ = np.mean(stds, axis=0)
619 | normalized_x_train = np.empty_like(x_train)
620 | for i in range(trials_no):
621 | temp = scalar.transform(x_train[i])
622 | normalized_x_train[i] = temp
623 |
624 | # transforming the test data
625 | normalized_x_test = np.empty_like(x_test)
626 | trials_no = x_test.shape[0]
627 | for i in range(trials_no):
628 | temp = scalar.transform(x_test[i])
629 | normalized_x_test[i] = temp
630 |
631 | x_train = normalized_x_train
632 | x_test = normalized_x_test
633 |
634 | return x_train, y_train, x_test, y_test
635 | def intersubjective_shallow(data, model_name):
636 |
637 | x_train, y_train, x_test, y_test, o_t_test, o_tr_test = data
638 |
639 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test), resample=False)
640 |
641 | global t_test
642 | t_test = o_t_test
643 | global tr_test
644 | tr_test = o_tr_test
645 | x_train = x_train.reshape(x_train.shape[0], -1)
646 | x_test = x_test.reshape(x_test.shape[0], -1)
647 | m = ['acc', 'val_acc', 'val_precisions', 'val_recalls', 'val_f1s', 'val_aucs',
648 | 'val_balanced_acc', 'val_recognition_acc', 'val_bpm']
649 | metrics = {key: [] for key in m }
650 | history = False
651 | if 'svm' in model_name:
652 | clf = svm.LinearSVC(random_state = 0)
653 | elif 'lda' in model_name:
654 | if 'shrinkage' in model_name:
655 | clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
656 | else:
657 | clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None)
658 | clf.fit(x_train, y_train)
659 | y_predict = clf.predict(x_test)
660 | probs = clf.decision_function(x_test)
661 | metrics['acc'].append(clf.score(x_train, y_train))
662 | metrics = compute_metrics(metrics, probs, y_predict, y_test)
663 | cnf_matrix = confusion_matrix(y_test, y_predict)
664 |
665 | return metrics, history, cnf_matrix, clf
666 |
667 | def intersubjective_training(data,model_name, subject, epochs=5, lr=0.001,
668 | batch_size=128,
669 | model_config={'bn':True, 'dropout':True, 'branched':True, 'nonlinear':'tanh'},
670 | early_stopping=True, patience=10):
671 |
672 |     global x_test, y_test # module-level: read by the Metrics callback during training
673 | x_tv, y_tv, x_test, y_test, o_t_test, o_tr_test = data
674 |
675 | if early_stopping:
676 | # pdb.set_trace()
677 | x_train, x_valid, y_train, y_valid = train_test_split(x_tv, y_tv,
678 | stratify=y_tv,
679 | test_size=0.2)
680 | global x_test
681 | x_train, y_train, x_test, y_test, x_valid, y_valid = resample_transform((x_train, y_train, x_test, y_test, x_valid, y_valid))
682 | else:
683 | x_train = x_tv
684 | y_train = y_tv
685 | global x_test
686 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test))
687 |
688 | global t_test
689 | t_test = o_t_test
690 | global tr_test
691 | tr_test = o_tr_test
692 |
693 | if 'branched' in model_name:
694 | if '250' in model_name:
695 | model = branched2(x_train.shape, model_config=model_config, f=5)
696 | else:
697 | model = branched2(x_train.shape, model_config=model_config, f=1)
698 | elif 'eegnet' in model_name:
699 | if '250' in model_name:
700 | model = create_eegnet(x_train.shape, f=4)
701 | else:
702 | model = create_eegnet(x_train.shape, f=1)
703 | elif 'cnn' in model_name:
704 | if '250' in model_name:
705 | model = create_cnn(x_train.shape, f=5)
706 | else:
707 | model = create_cnn(x_train.shape, f=1)
708 |
709 | lrate = LearningRateScheduler(step_decay)
710 | adam = Adam(lr=lr)
711 | # lrate = LearningRateScheduler(step_decay)
712 | model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
713 |
714 | m = Metrics()
715 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
716 | patience=int(patience/2), min_lr=0)
717 | early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=0, mode='auto')
718 | mod_path = './models/subjects/'+subject
719 | timestr = time.strftime("%Y%m%d-%H%M")
720 |
721 |
722 | checkpointer = ModelCheckpoint(filepath=mod_path+'/best_'+model_name+'_'+timestr,
723 | monitor='val_loss', verbose=1, save_best_only=True)
724 | if early_stopping:
725 | history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, shuffle=True, verbose=2,
726 | validation_data=(x_valid, y_valid), callbacks=[m, early_stop, checkpointer, reduce_lr])
727 | else:
728 | history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, shuffle=True, verbose=2,
729 | validation_data=(x_test, y_test), callbacks=[m])
730 |
731 |
732 |
733 |
734 | probabilities = model.predict(x_test, batch_size=batch_size, verbose=0)
735 | y_predict = [(round(k)) for k in probabilities]
736 |
737 | cnf_matrix = confusion_matrix(y_test, y_predict)
738 |
739 | return m, history, cnf_matrix
740 |
741 |
742 |
743 | def finetune(model, data, model_name, subject, epochs=10, train_trials=40, mode='all', early_stopping=True, patience=10):
744 |     global x_test, y_test # module-level: assigned below and read by the Metrics callback
745 | for layer in model.layers[:26]:
746 | if mode=='all':
747 | layer.trainable = True
748 | elif mode=='top':
749 | layer.trainable = False
750 | else:
751 | print 'wrong keyword argument'
752 | return
753 | # print model.summary()
754 |
755 | opt = SGD(lr=1e-4, momentum=0.9)
756 |
757 | model.compile(loss='binary_crossentropy', optimizer=opt , metrics=['accuracy'])
758 |
759 | x_test, y_test, o_t_test, o_tr_test = data
760 |
761 | segments = train_trials * 60
762 |
763 | x_tv = x_test[0:segments]
764 | y_tv = y_test[0:segments]
765 |
766 | if early_stopping:
767 | x_train, x_valid, y_train, y_valid = train_test_split(x_tv, y_tv,
768 | stratify=y_tv,
769 | test_size=0.2)
770 | x_test = x_test[segments:]
771 | global y_test
772 | y_test = y_test[segments:]
773 | global x_test
774 | x_train, y_train, x_test, y_test, x_valid, y_valid = resample_transform((x_train, y_train, x_test, y_test, x_valid, y_valid))
775 | else:
776 | x_train = x_tv
777 | y_train = y_tv
778 | x_test = x_test[segments:]
779 | global y_test
780 | y_test = y_test[segments:]
781 | global x_test
782 | x_train, y_train, x_test, y_test = resample_transform((x_train, y_train, x_test, y_test))
783 |
784 | # x_test = x_test[segments:]
785 | # y_test = y_test[segments:]
786 |
787 |
788 |
789 | global t_test
790 | t_test = o_t_test[segments:]
791 | global tr_test
792 | tr_test = o_tr_test[segments:]
793 | # #resampling the data
794 | # n_samples, timepoints, channels, z = x_train.shape
795 | # x_train = np.reshape(x_train, (n_samples, timepoints * channels))
796 | # ros = RandomOverSampler(random_state=0)
797 | # x_res, y_res = ros.fit_sample(x_train, y_train)
798 | # x_train = np.reshape(x_res, (x_res.shape[0], timepoints, channels))
799 | # y_train = y_res
800 | #
801 | # x_train = np.expand_dims(x_train, axis=3)
802 | # x_test = np.expand_dims(x_test, axis=3)
803 |
804 | early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=0, mode='auto')
805 |
806 |
807 | m = Metrics()
808 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
809 | patience=int(patience/2), min_lr=0)
810 | mod_path = './models/subjects/'+subject
811 | timestr = time.strftime("%Y%m%d-%H%M")
812 |
813 | checkpointer = ModelCheckpoint(filepath=mod_path+'/best_'+model_name+'_'+timestr,
814 | monitor='val_loss', verbose=1, save_best_only=True)
815 | if early_stopping:
816 | history = model.fit(x_train, y_train, batch_size=64, epochs=epochs, shuffle=True, verbose=2,
817 | validation_data=(x_valid, y_valid), callbacks=[m, early_stop, checkpointer, reduce_lr])
818 | else:
819 | history = model.fit(x_train, y_train, batch_size=64, epochs=epochs, shuffle=True, verbose=2,
820 | validation_data=(x_test, y_test), callbacks=[m])
821 |
822 |
823 | probabilities = model.predict(x_test, batch_size=128, verbose=0)
824 | y_predict = [(round(k)) for k in probabilities]
825 |
826 | cnf_matrix = confusion_matrix(y_test, y_predict)
827 |
828 |
829 | return m, history, cnf_matrix
830 |
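831 | # Note on the recognition metrics above: each trial contains 60 stimulations =
832 | # 12 stimuli x 5 repetitions. recognition_accuracy accumulates the per-stimulus
833 | # probabilities with np.cumsum over the repetition axis, so n_correct[j] is the
834 | # letter-recognition accuracy after j+1 repetitions; Metrics then converts these
835 | # five accuracies to bits/minute via bitperminute with n = 12 and t = 2,4,...,10 s.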
--------------------------------------------------------------------------------