├── .DS_Store
├── .gitignore
├── .idea
│   ├── codeStyles
│   │   └── codeStyleConfig.xml
│   ├── dictionaries
│   │   ├── dave.xml
│   │   └── davidemiani.xml
│   ├── encodings.xml
│   ├── hgdecode.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __init__.py
├── dl_main.py
├── dl_main_cross_subject.py
├── hgdecode
│   ├── classes.py
│   ├── experiments.py
│   ├── fbcsprlda.py
│   ├── lda.py
│   ├── loaders.py
│   ├── models.py
│   ├── signalproc.py
│   └── utils.py
├── ml_main.py
├── ml_main_cross_subject.py
├── schirrmeister_main.py
├── sub_routines
│   ├── dl_cross_validation.py
│   ├── latex_tabular_parser.py
│   ├── latex_tabular_parser_cross_subj.py
│   ├── latex_tabular_parser_transfer_learning.py
│   ├── latex_tabular_parser_transfer_learning_frozen_layers.py
│   ├── learning_curve.py
│   ├── ml_cross_validation.py
│   ├── t_test.py
│   ├── transfer_learning_curve.py
│   └── transfer_learning_curve_frozen_layers.py
└── transfer_learning.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidemiani/hgdecode/359715c35035b7a37c0767e221a0ffa73254c7be/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 |
4 | \.idea/codeStyles/codeStyleConfig\.xml
5 |
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/dictionaries/dave.xml:
--------------------------------------------------------------------------------
 1 | <component name="ProjectDictionaryState">
 2 |   <dictionary name="dave">
 3 |     <words>
 4 |       <w>accs</w>
 5 |       <w>arxiv</w>
 6 |       <w>asctime</w>
 7 |       <w>bandpassed</w>
 8 |       <w>bandpassing</w>
 9 |       <w>bbci</w>
10 |       <w>bestfilterband</w>
11 |       <w>bincsp</w>
12 |       <w>butterworth</w>
13 |       <w>conv</w>
14 |       <w>convolutional</w>
15 |       <w>crossentropy</w>
16 |       <w>datasets</w>
17 |       <w>depthwise</w>
18 |       <w>dmdl</w>
19 |       <w>eigenvectors</w>
20 |       <w>epoched</w>
21 |       <w>epoching</w>
22 |       <w>fbcsp</w>
23 |       <w>filt</w>
24 |       <w>filterband</w>
25 |       <w>filterbands</w>
26 |       <w>filterbank</w>
27 |       <w>filterpair</w>
28 |       <w>filterwise</w>
29 |       <w>hgdecode</w>
30 |       <w>hyperparameter</w>
31 |       <w>hyperparameters</w>
32 |       <w>inds</w>
33 |       <w>ival</w>
34 |       <w>keras</w>
35 |       <w>levelname</w>
36 |       <w>miani</w>
37 |       <w>microvolt</w>
38 |       <w>multiclass</w>
39 |       <w>ndarray</w>
40 |       <w>npint</w>
41 |       <w>overcomplete</w>
42 |       <w>pointwise</w>
43 |       <w>preds</w>
44 |       <w>regularizations</w>
45 |       <w>regularizer</w>
46 |       <w>resample</w>
47 |       <w>resampling</w>
48 |       <w>robintibor</w>
49 |       <w>schirrmeister</w>
50 |       <w>scipy</w>
51 |       <w>sfreq</w>
52 |       <w>softmax</w>
53 |       <w>ssvep</w>
54 |       <w>waytowich</w>
55 |     </words>
56 |   </dictionary>
57 | </component>
--------------------------------------------------------------------------------
/.idea/dictionaries/davidemiani.xml:
--------------------------------------------------------------------------------
 1 | <component name="ProjectDictionaryState">
 2 |   <dictionary name="davidemiani">
 3 |     <words>
 4 |       <w>algo</w>
 5 |       <w>barstr</w>
 6 |       <w>dataset</w>
 7 |       <w>davidemiani</w>
 8 |       <w>filtfilt</w>
 9 |       <w>frac</w>
10 |       <w>hline</w>
11 |       <w>idxs</w>
12 |       <w>ipykernel</w>
13 |       <w>isatty</w>
14 |       <w>magistrale</w>
15 |       <w>numdigits</w>
16 |       <w>nyquist</w>
17 |       <w>occurrs</w>
18 |       <w>perc</w>
19 |       <w>prec</w>
20 |       <w>pred</w>
21 |       <w>preprocess</w>
22 |       <w>prog</w>
23 |       <w>pval</w>
24 |       <w>rowcolor</w>
25 |       <w>seaborn</w>
26 |       <w>stdd</w>
27 |       <w>studiorum</w>
28 |       <w>subjs</w>
29 |       <w>tesi</w>
30 |       <w>textbf</w>
31 |       <w>totale</w>
32 |       <w>universita</w>
33 |       <w>whitegrid</w>
34 |       <w>xticks</w>
35 |     </words>
36 |   </dictionary>
37 | </component>
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/hgdecode.iml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this export; no content recoverable]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hgdecode
2 | Implements High-Gamma dataset decoding using Filter Bank Common Spatial Patterns (FBCSP) with regularized LDA (rLDA) classification, and with convolutional neural networks (DeepConvNet / ShallowNet).
3 |
--------------------------------------------------------------------------------
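
Note: the pipeline named in the README (FBCSP features fed to an rLDA classifier) is implemented in hgdecode/fbcsprlda.py below. As a minimal, self-contained numpy sketch of the core feature idea (log-variance of CSP-projected trials scored against a linear hyperplane), with purely illustrative names, not the package API:

    import numpy as np

    def log_var_features(trials, csp_filters):
        # trials: (n_trials, n_channels, n_samples)
        # csp_filters: (n_channels, n_components), one spatial filter per column
        projected = np.einsum('ck,tcs->tks', csp_filters, trials)
        # log-variance over time of every CSP component, one row per trial
        return np.log(projected.var(axis=2))

    def lda_decision(features, w, b):
        # linear decision rule: class 1 if w . x - b >= 0; this mirrors the
        # (hyperplane parameters, bias) tuple documented in fbcsprlda.py
        return features @ w - b >= 0

    rng = np.random.RandomState(0)
    trials = rng.randn(4, 8, 250)   # 4 trials, 8 channels, 250 samples
    filters = rng.randn(8, 6)       # 6 illustrative spatial filters
    print(lda_decision(log_var_features(trials, filters), rng.randn(6), 0.0))

In the real package the spatial filters come from calculate_csp on bandpassed trials, one set per filterband and class pair, and the hyperplane from lda_train_scaled.
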
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidemiani/hgdecode/359715c35035b7a37c0767e221a0ffa73254c7be/__init__.py
--------------------------------------------------------------------------------
/dl_main.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from os.path import join
3 | from os.path import dirname
4 | from collections import OrderedDict
5 | from numpy.random import RandomState
6 | from hgdecode.utils import create_log
7 | from hgdecode.utils import print_manager
8 | from hgdecode.loaders import dl_loader
9 | from hgdecode.classes import CrossValidation
10 | from hgdecode.experiments import DLExperiment
11 | from keras import backend as K
12 |
13 | """
14 | SETTING PARAMETERS
15 | ------------------
16 | In the following, you have to set / modify all the parameters to use for
17 | further computation.
18 |
19 | Parameters
20 | ----------
21 | channel_names : list
22 | Channels to use for computation
23 | data_dir : str
24 | Path to the directory that contains dataset
25 | model_name : str
26 | Name of the Deep Learning model
27 | name_to_start_codes : OrderedDict
28 | All possible classes names and codes in an ordered dict format
29 | random_state : RandomState
30 | Random state used for all random-dependent computations
31 | results_dir : str
32 | Path to the directory that will contain the results
33 | subject_ids : tuple
34 | All the subject ids in a tuple; add or remove subjects to run the
35 | algorithm for them or not
36 | """
37 | # setting model_name
38 | model_name = 'DeepConvNet' # Schirrmeister: 'DeepConvNet' or 'ShallowNet'
39 |
40 | # setting channel_names
41 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
42 | 'CP5', 'CP1', 'CP2', 'CP6',
43 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
44 | 'CP3', 'CPz', 'CP4',
45 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
46 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
47 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
48 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
49 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
50 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
51 |
52 | # setting data_dir & results_dir
53 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
54 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
55 |
56 | # setting name_to_start_codes
57 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
58 | ('Left Hand', [2]),
59 | ('Rest', [3]),
60 | ('Feet', [4])])
61 |
62 | # setting random_state
63 | random_state = RandomState(1234)
64 |
65 | # the hyperparameters you will most likely want to tweak
66 | standardize_mode = 2
67 | subject_ids = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
68 | ival = (-500, 4000)
69 | n_folds = 12
70 | fold_size = None
71 | swap_train_test = False
72 | learning_rate = 1e-4
73 | dropout_rate = 0.5
74 | batch_size = 64
75 | epochs = 1000
76 |
77 | """
78 | MAIN CYCLE
79 | ----------
80 | For each subject, a new log will be created and the specific dataset loaded;
81 | this dataset will be used to create an instance of the experiment; then the
82 | experiment will be run. You can of course change all the experiment inputs
83 | to obtain different results.
84 | """
85 | for subject_id in subject_ids:
86 | # creating a log object
87 | subj_results_dir = create_log(
88 | results_dir=results_dir,
89 | learning_type='dl',
90 | algorithm_or_model_name=model_name,
91 | subject_id=subject_id,
92 | output_on_file=False
93 | )
94 |
95 | # loading epoched signal
96 | epo = dl_loader(
97 | data_dir=data_dir,
98 | name_to_start_codes=name_to_start_codes,
99 | channel_names=channel_names,
100 | subject_id=subject_id,
101 | resampling_freq=250, # Schirrmeister: 250
102 | clean_ival_ms=ival, # Schirrmeister: (0, 4000)
103 | epoch_ival_ms=ival, # Schirrmeister: (-500, 4000)
104 | train_test_split=True, # Schirrmeister: True
105 | clean_on_all_channels=False, # Schirrmeister: True
106 | standardize_mode=standardize_mode # Schirrmeister: 2
107 | )
108 |
109 | # creating CrossValidation class instance
110 | cv = CrossValidation(
111 | X=epo.X,
112 | y=epo.y,
113 | n_folds=n_folds,
114 | fold_size=fold_size,
115 | validation_frac=0.1,
116 | random_state=random_state,
117 | shuffle=True,
118 | swap_train_test=swap_train_test
119 | )
120 | if n_folds is None:
121 | cv.balance_train_set(train_size=fold_size)
122 |
123 | # pre-allocating experiment
124 | exp = None
125 |
126 | # cycling on folds for cross validation
127 | for fold_idx, current_fold in enumerate(cv.folds):
128 | # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
129 | print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
130 | K.clear_session()
131 | print_manager(print_style='last', bottom_return=1)
132 |
133 | # printing fold information
134 | print_manager(
135 | 'SUBJECT {}, FOLD {}'.format(subject_id, fold_idx + 1),
136 | print_style='double-dashed'
137 | )
138 | cv.print_fold_classes(fold_idx)
139 | print_manager(print_style='last', bottom_return=1)
140 |
141 | # creating EEGDataset for current fold
142 | dataset = cv.create_dataset(fold=current_fold)
143 |
144 | # creating experiment instance
145 | exp = DLExperiment(
146 | # non-default inputs
147 | dataset=dataset,
148 | model_name=model_name,
149 | results_dir=results_dir,
150 | subj_results_dir=subj_results_dir,
151 | name_to_start_codes=name_to_start_codes,
152 | random_state=random_state,
153 | fold_idx=fold_idx,
154 |
155 | # hyperparameters
156 | dropout_rate=dropout_rate, # Schirrmeister: 0.5
157 | learning_rate=learning_rate, # Schirrmeister: ?
158 | batch_size=batch_size, # Schirrmeister: 512
159 | epochs=epochs, # Schirrmeister: ?
160 | early_stopping=False, # Schirrmeister: ?
161 | monitor='val_acc', # Schirrmeister: ?
162 | min_delta=0.0001, # Schirrmeister: ?
163 | patience=5, # Schirrmeister: ?
164 | loss='categorical_crossentropy', # Schirrmeister: ad hoc
165 | optimizer='Adam', # Schirrmeister: Adam
166 | shuffle=True, # Schirrmeister: ?
167 | crop_sample_size=None, # Schirrmeister: 1125
168 | crop_step=None, # Schirrmeister: 1
169 |
170 | # other parameters
171 | subject_id=subject_id,
172 | data_generator=False, # Schirrmeister: True
173 | save_model_at_each_epoch=False
174 | )
175 |
176 | # training
177 | exp.train()
178 |
179 | if exp is not None:
180 | # computing cross-validation
181 | cv.cross_validate(subj_results_dir=subj_results_dir,
182 | label_names=name_to_start_codes)
183 |
--------------------------------------------------------------------------------
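
Note: the CrossValidation class consumed above lives in hgdecode/classes.py (not included in this section); the fold convention it shares with FBCSPrLDAExperiment.create_folds in hgdecode/experiments.py is a list of dicts holding disjoint 'train' and 'test' trial-index arrays. A self-contained numpy sketch of that convention (make_folds is illustrative, not the package API):

    import numpy as np
    from numpy import arange, setdiff1d

    def make_folds(n_trials, n_folds, rng):
        # each fold is {'train': ..., 'test': ...} with disjoint index arrays
        test_blocks = np.array_split(rng.permutation(n_trials), n_folds)
        return [{'train': setdiff1d(arange(n_trials), test), 'test': test}
                for test in test_blocks]

    folds = make_folds(n_trials=120, n_folds=12, rng=np.random.RandomState(1234))
    print(len(folds), folds[0]['test'].shape, folds[0]['train'].shape)
    # 12 folds: 10 test trials and 110 train trials each
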
/dl_main_cross_subject.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from os.path import join
3 | from os.path import dirname
4 | from collections import OrderedDict
5 | from numpy.random import RandomState
6 | from hgdecode.utils import create_log
7 | from hgdecode.utils import print_manager
8 | from hgdecode.loaders import CrossSubject
9 | from hgdecode.classes import CrossValidation
10 | from hgdecode.experiments import DLExperiment
11 | from keras import backend as K
12 |
13 | """
14 | SETTING PARAMETERS
15 | Here you can set whatever parameter you want
16 | """
17 | # setting model_name
18 | model_name = 'DeepConvNet'
19 |
20 | # setting channel_names
21 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
22 | 'CP5', 'CP1', 'CP2', 'CP6',
23 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
24 | 'CP3', 'CPz', 'CP4',
25 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
26 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
27 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
28 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
29 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
30 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
31 |
32 | # setting data_dir & results_dir
33 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
34 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
35 |
36 | # setting name_to_start_codes
37 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
38 | ('Left Hand', [2]),
39 | ('Rest', [3]),
40 | ('Feet', [4])])
41 |
42 | # setting random_state
43 | random_state = RandomState(1234)
44 |
45 | # setting subject_ids
46 | subject_ids = (1, 2) # , 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
47 |
48 | # setting hyperparameters
49 | ival = (-500, 4000)
50 | standardize_mode = 2
51 | learning_rate = 1 * 1e-4
52 | dropout_rate = 0.5
53 | batch_size = 32
54 | epochs = 800
55 |
56 | """
57 | STARTING LOADING ROUTINE & COMPUTATION
58 | Here you can change some parameters in the function calls as well
59 | """
60 | # creating a log object
61 | subj_results_dir = create_log(
62 | results_dir=results_dir,
63 | learning_type='dl',
64 | algorithm_or_model_name=model_name,
65 | subject_id='subj_cross',
66 | output_on_file=False
67 | )
68 |
69 | # creating a cross-subject object for cross-subject validation
70 | cross_obj = CrossSubject(data_dir=data_dir,
71 | subject_ids=subject_ids,
72 | channel_names=channel_names,
73 | name_to_start_codes=name_to_start_codes,
74 | random_state=random_state,
75 | validation_frac=0.1,
76 | resampling_freq=250,
77 | train_test_split=True,
78 | clean_ival_ms=ival,
79 | epoch_ival_ms=ival,
80 | clean_on_all_channels=False)
81 |
82 | # parsing all cnt data to epoched (we no longer need cnt)
83 | cross_obj.parser(output_format='epo', parsing_type=1)
84 |
85 | # pre-allocating experiment
86 | exp = None
87 |
88 | # cycling over the subjects, each one left out in turn
89 | for leave_subj in subject_ids:
90 | # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
91 | print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
92 | K.clear_session()
93 | print_manager(print_style='last', bottom_return=1)
94 |
95 | # creating dataset for this "all but" fold
96 | cross_obj.parser(output_format='EEGDataset',
97 | leave_subj=leave_subj,
98 | parsing_type=1)
99 |
100 | # creating experiment instance
101 | exp = DLExperiment(
102 | # non-default inputs
103 | dataset=cross_obj.fold_data,
104 | model_name=model_name,
105 | results_dir=results_dir,
106 | name_to_start_codes=name_to_start_codes,
107 | random_state=random_state,
108 | fold_idx=leave_subj - 1,
109 |
110 | # hyperparameters
111 | dropout_rate=dropout_rate, # Schirrmeister: 0.5
112 | learning_rate=learning_rate, # Schirrmeister: ?
113 | batch_size=batch_size, # Schirrmeister: 512
114 | epochs=epochs, # Schirrmeister: ?
115 | early_stopping=False, # Schirrmeister: ?
116 | monitor='val_acc', # Schirrmeister: ?
117 | min_delta=0.0001, # Schirrmeister: ?
118 | patience=5, # Schirrmeister: ?
119 | loss='categorical_crossentropy', # Schirrmeister: ad hoc
120 | optimizer='Adam', # Schirrmeister: Adam
121 | shuffle=True, # Schirrmeister: ?
122 | crop_sample_size=None, # Schirrmeister: 1125
123 | crop_step=None, # Schirrmeister: 1
124 |
125 | # other parameters
126 | subject_id='_cross',
127 | data_generator=False, # Schirrmeister: True
128 | save_model_at_each_epoch=False,
129 | subj_results_dir=subj_results_dir
130 | )
131 |
132 | # running training
133 | exp.train()
134 |
135 | # at the very end, running cross-validation
136 | if exp is not None:
137 | CrossValidation.cross_validate(subj_results_dir=exp.subj_results_dir,
138 | label_names=name_to_start_codes)
139 |
--------------------------------------------------------------------------------
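
Note: the leave-one-subject-out scheme driven here by CrossSubject.parser matches the cross-subject branch of FBCSPrLDAExperiment.create_folds in hgdecode/experiments.py: every subject's contiguous block of trials becomes the test fold and all remaining trials the train fold. A self-contained sketch with illustrative subject ranges:

    from numpy import arange, setdiff1d

    # illustrative (start, stop) trial ranges, one per subject
    subject_indexes = [(0, 160), (160, 320)]
    n_trials = 320

    folds = [{'train': setdiff1d(arange(n_trials), arange(start, stop)),
              'test': arange(start, stop)}
             for start, stop in subject_indexes]
    print([len(fold['test']) for fold in folds])  # [160, 160]
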
/hgdecode/experiments.py:
--------------------------------------------------------------------------------
1 | # General modules
2 | import numpy as np
3 | from numpy import arange
4 | from numpy import setdiff1d
5 | from numpy import int as npint
6 | from pickle import load
7 | from os.path import join
8 | from os.path import dirname
9 | from itertools import combinations
10 | from hgdecode.utils import touch_dir
11 | from hgdecode.utils import my_formatter
12 | from hgdecode.utils import print_manager
13 | from sklearn.metrics import confusion_matrix
14 | from multiprocessing import cpu_count
15 |
16 | # Deep Learning
17 | from hgdecode import models
18 | from keras import optimizers
19 | from keras.callbacks import CSVLogger
20 | from keras.callbacks import EarlyStopping
21 | from keras.callbacks import ModelCheckpoint
22 | from hgdecode.classes import MetricsTracker
23 | from hgdecode.classes import EEGDataGenerator
24 |
25 | # Machine Learning
26 | from hgdecode.classes import FilterBank
27 | from hgdecode.fbcsprlda import BinaryFBCSP
28 | from hgdecode.fbcsprlda import FBCSP
29 | from hgdecode.fbcsprlda import MultiClassWeightedVoting
30 | from braindecode.datautil.iterators import get_balanced_batches
31 |
32 |
33 | class FBCSPrLDAExperiment(object):
34 | """
35 | A Filter Bank Common Spatial Patterns with rLDA
36 | classification Experiment.
37 |
38 | Parameters
39 | ----------
40 | cnt : RawArray
41 | The continuous train recordings with events in info['events']
42 | clean_trial_mask : bool array
43 | Bool array containing information about valid/invalid trials
44 | name_to_start_codes: dict
45 | Dictionary mapping class names to marker numbers, e.g.
46 | {'1 - Correct': [31], '2 - Error': [32]}
47 | epoch_ival_ms : sequence of 2 floats
48 | The start and end of the trial in milliseconds with respect to
49 | the markers.
50 | min_freq : int or list or tuple
51 | The minimum frequency/ies of the filterbank/s.
52 | max_freq : int or list or tuple
53 | The maximum frequency/ies of the filterbank/s.
54 | window : int or list or tuple
55 | Bandwidths of filters in filterbank/s.
56 | overlap : int or list or tuple
57 | Overlap frequencies between filters in filterbank/s.
58 | filt_order : int
59 | The filter order of the butterworth filter which computes the
60 | filterbands.
61 | n_folds : int
62 | How many folds. Also determines size of the test fold, e.g.
63 | 5 folds imply the test fold has 20% of the original data.
64 | n_top_bottom_csp_filters : int or None
65 | Number of top and bottom CSP filters to select from all computed
66 | filters. Top and bottom refers to CSP filters sorted by their
67 | eigenvalues. So a value of 3 here will lead to 6(!) filters.
68 | None means all filters.
69 | n_selected_filterbands : int or None
70 | Number of filterbands to select for the filterbank.
71 | Will be selected by the highest training accuracies.
72 | None means all filterbands.
73 | n_selected_features : int or None
74 | Number of features to select for the filterbank.
75 | Will be selected by an internal cross validation across feature
76 | subsets.
77 | None means all features.
78 | forward_steps : int
79 | Number of forward steps to make in the feature selection,
80 | before the next backward step.
81 | backward_steps : int
82 | Number of backward steps to make in the feature selection,
83 | before the next forward step.
84 | stop_when_no_improvement: bool
85 | Whether to stop the feature selection if the internal cross
86 | validation accuracy could not be improved after an epoch finished
87 | (epoch=given number of forward and backward steps).
88 | False implies always run until wanted number of features.
89 | shuffle: bool
90 | Whether to shuffle the clean trials before splitting them into
91 | folds. False implies folds are time-blocks, True implies folds are
92 | random mixes of trials of the entire file.
93 | """
94 |
95 | def __init__(self,
96 | # signal-related inputs
97 | cnt,
98 | clean_trial_mask,
99 | name_to_start_codes,
100 | random_state,
101 | name_to_stop_codes=None,
102 | epoch_ival_ms=(-500, 4000),
103 | cross_subject_object=None,
104 |
105 | # bank filter-related inputs
106 | min_freq=0,
107 | max_freq=12,
108 | window=6,
109 | overlap=3,
110 | filt_order=3,
111 |
112 | # machine learning-related inputs
113 | n_folds=5,
114 | fold_file=None,
115 | n_top_bottom_csp_filters=None,
116 | n_selected_filterbands=None,
117 | n_selected_features=None,
118 | forward_steps=2,
119 | backward_steps=1,
120 | stop_when_no_improvement=False,
121 | shuffle=False,
122 | average_trial_covariance=True):
123 | # signal-related inputs
124 | self.cnt = cnt
125 | self.clean_trial_mask = clean_trial_mask
126 | self.epoch_ival_ms = epoch_ival_ms
127 | self.name_to_start_codes = name_to_start_codes
128 | self.name_to_stop_codes = name_to_stop_codes
129 | self.random_state = random_state
130 | if cross_subject_object is None:
131 | self.cross_subject_object = None
132 | self.cross_subject_computation = False
133 | else:
134 | self.cross_subject_object = cross_subject_object
135 | self.cross_subject_computation = True
136 |
137 | # bank filter-related inputs
138 | self.min_freq = min_freq
139 | self.max_freq = max_freq
140 | self.window = window
141 | self.overlap = overlap
142 | self.filt_order = filt_order
143 |
144 | # machine learning-related inputs
145 | self.n_folds = n_folds
146 | self.n_top_bottom_csp_filters = n_top_bottom_csp_filters
147 | self.n_selected_filterbands = n_selected_filterbands
148 | self.n_selected_features = n_selected_features
149 | self.forward_steps = forward_steps
150 | self.backward_steps = backward_steps
151 | self.stop_when_no_improvement = stop_when_no_improvement
152 | self.shuffle = shuffle
153 | self.average_trial_covariance = average_trial_covariance
154 | if fold_file is None:
155 | self.fold_file = None
156 | self.load_fold_from_file = False
157 | else:
158 | self.fold_file = fold_file
159 | self.load_fold_from_file = True
160 |
161 | # other fundamental properties (they will be filled in further
162 | # computational steps)
163 | self.filterbank_csp = None
164 | self.class_pairs = None
165 | self.folds = None
166 | self.binary_csp = None
167 | self.filterbands = None
168 | self.multi_class = None
169 |
170 | # computing other properties for further computation
171 | self.n_classes = len(self.name_to_start_codes)
172 | self.class_pairs = list(combinations(range(self.n_classes), 2))
173 | self.n_trials = self.clean_trial_mask.astype(npint).sum()
174 |
175 | def create_filter_bank(self):
176 | self.filterbands = FilterBank(
177 | min_freq=self.min_freq,
178 | max_freq=self.max_freq,
179 | window=self.window,
180 | overlap=self.overlap
181 | )
182 |
183 | def create_folds(self):
184 | if self.cross_subject_computation is True:
185 | # in case of cross-subject computation
186 | folds = [
187 | arange(
188 | self.cross_subject_object.subject_indexes[x][0],
189 | self.cross_subject_object.subject_indexes[x][1]
190 | )
191 | for x in range(len(self.cross_subject_object.subject_indexes))
192 | ]
193 | self.n_folds = len(folds)
194 | self.folds = [
195 | {
196 | 'train': setdiff1d(arange(self.n_trials), fold),
197 | 'test': fold
198 | }
199 | for fold in folds
200 | ]
201 | elif self.load_fold_from_file is True:
202 | # in case of pre-batched computation
203 | self.folds = np.load(self.fold_file)['folds']
204 | elif self.n_folds == 0:
205 | self.n_folds = 1
206 |
207 | # creating schirrmeister fold
208 | all_idxs = np.array(range(len(self.clean_trial_mask)))
209 | self.folds = [
210 | {
211 | 'train': all_idxs[:-160],
212 | 'test': all_idxs[-160:]
213 | }
214 | ]
215 | self.folds[0]['train'] = self.folds[0]['train'][
216 | self.clean_trial_mask[:-160]]
217 | self.folds[0]['test'] = self.folds[0]['test'][
218 | self.clean_trial_mask[-160:]]
219 | else:
220 | # getting pseudo-random folds
221 | folds = get_balanced_batches(
222 | n_trials=self.n_trials,
223 | rng=self.random_state,
224 | shuffle=self.shuffle,
225 | n_batches=self.n_folds
226 | )
227 | self.folds = [
228 | {
229 | 'train': setdiff1d(arange(self.n_trials), fold),
230 | 'test': fold
231 | }
232 | for fold in folds
233 | ]
234 |
235 | def run(self):
236 | # printing routine start
237 | print_manager(
238 | 'INIT TRAINING ROUTINE',
239 | 'double-dashed',
240 | )
241 |
242 | # creating filter bank
243 | print_manager('Creating filter bank...')
244 | self.create_filter_bank()
245 | print_manager('DONE!!', bottom_return=1)
246 |
247 | # creating folds
248 | print_manager('Creating folds...')
249 | self.create_folds()
250 | print_manager('DONE!!', 'last')
251 |
252 | # running binary FBCSP
253 | print_manager("RUNNING BINARY FBCSP rLDA",
254 | 'double-dashed',
255 | top_return=1)
256 | self.binary_csp = BinaryFBCSP(
257 | cnt=self.cnt,
258 | clean_trial_mask=self.clean_trial_mask,
259 | filterbands=self.filterbands,
260 | filt_order=self.filt_order,
261 | folds=self.folds,
262 | class_pairs=self.class_pairs,
263 | epoch_ival_ms=self.epoch_ival_ms,
264 | n_filters=self.n_top_bottom_csp_filters,
265 | marker_def=self.name_to_start_codes,
266 | name_to_stop_codes=self.name_to_stop_codes,
267 | average_trial_covariance=self.average_trial_covariance
268 | )
269 | self.binary_csp.run()
270 |
271 | # at the very end of the binary CSP experiment, running the real one
272 | print_manager("RUNNING FBCSP rLDA", 'double-dashed', top_return=1)
273 | self.filterbank_csp = FBCSP(
274 | binary_csp=self.binary_csp,
275 | n_features=self.n_selected_features,
276 | n_filterbands=self.n_selected_filterbands,
277 | forward_steps=self.forward_steps,
278 | backward_steps=self.backward_steps,
279 | stop_when_no_improvement=self.stop_when_no_improvement
280 | )
281 | self.filterbank_csp.run()
282 |
283 | # and finally multiclass
284 | print_manager("RUNNING MULTICLASS", 'double-dashed', top_return=1)
285 | self.multi_class = MultiClassWeightedVoting(
286 | train_labels=self.binary_csp.train_labels_full_fold,
287 | test_labels=self.binary_csp.test_labels_full_fold,
288 | train_preds=self.filterbank_csp.train_pred_full_fold,
289 | test_preds=self.filterbank_csp.test_pred_full_fold,
290 | class_pairs=self.class_pairs)
291 | self.multi_class.run()
292 | print('\n')
293 |
294 |
295 | class DLExperiment(object):
296 | """
297 | A Deep Learning decoding experiment on an EEGDataset fold: builds the requested Keras model from hgdecode.models, then trains and evaluates it.
298 | """
299 |
300 | def __init__(self,
301 | # non-default inputs
302 | dataset,
303 | model_name,
304 | results_dir,
305 | subj_results_dir,
306 | name_to_start_codes,
307 | random_state,
308 | fold_idx,
309 |
310 | # hyperparameters
311 | dropout_rate=0.5,
312 | learning_rate=0.001,
313 | batch_size=128,
314 | epochs=10,
315 | early_stopping=False,
316 | monitor='val_acc',
317 | min_delta=0.0001,
318 | patience=5,
319 | loss='categorical_crossentropy',
320 | optimizer='Adam',
321 | shuffle=False,
322 | crop_sample_size=None,
323 | crop_step=None,
324 |
325 | # other parameters
326 | subject_id=1,
327 | data_generator=False,
328 | workers=cpu_count(),
329 | save_model_at_each_epoch=False):
330 | # non-default inputs
331 | self.dataset = dataset
332 | self.model_name = model_name
333 | self.results_dir = results_dir
334 | self.subj_results_dir = subj_results_dir
335 | self.datetime_results_dir = dirname(subj_results_dir)
336 | self.name_to_start_codes = name_to_start_codes
337 | self.random_state = random_state
338 | self.fold_idx = fold_idx
339 |
340 | # hyperparameters
341 | self.dropout_rate = dropout_rate
342 | self.learning_rate = learning_rate
343 | self.batch_size = batch_size
344 | self.epochs = epochs
345 | self.early_stopping = early_stopping
346 | self.monitor = monitor
347 | self.min_delta = min_delta
348 | self.patience = patience
349 | self.loss = loss
350 | self.optimizer = optimizer
351 | self.shuffle = shuffle
352 | if crop_sample_size is None:
353 | self.crop_sample_size = self.n_samples
354 | self.crop_step = 1
355 | else:
356 | self.crop_sample_size = crop_sample_size
357 | self.crop_step = crop_step
358 |
359 | # other parameters
360 | self.subject_id = subject_id
361 | self.data_generator = data_generator
362 | self.workers = workers
363 | self.save_model_at_each_epoch = save_model_at_each_epoch
364 | self.metrics_tracker = None
365 |
366 | # managing paths
367 | self.dl_results_dir = None
368 | self.model_results_dir = None
369 | self.fold_results_dir = None
370 | self.statistics_dir = None
371 | self.figures_dir = None
372 | self.tables_dir = None
373 | self.model_picture_path = None
374 | self.model_report_path = None
375 | self.train_report_path = None
376 | self.h5_models_dir = None
377 | self.h5_model_path = None
378 | self.log_path = None
379 | self.fold_stats_path = None
380 | self.paths_manager()
381 |
382 | # importing model
383 | print_manager('IMPORTING & COMPILING MODEL', 'double-dashed')
384 | # looking up the model constructor by name on the models module
385 | # (avoids eval and passes the arguments directly)
386 | self.model = getattr(models, self.model_name)(self.n_classes,
387 | self.n_channels,
388 | self.crop_sample_size,
389 | self.dropout_rate)
390 |
391 | # creating optimizer instance (Adam is the only supported option)
392 | if self.optimizer == 'Adam':
393 | opt = optimizers.Adam(lr=self.learning_rate)
394 | else:
395 | opt = optimizers.Adam(lr=self.learning_rate)
396 |
397 | # compiling model
398 | self.model.compile(loss=self.loss,
399 | optimizer=opt,
400 | metrics=['accuracy'])
401 | self.model.summary()
402 | print_manager('DONE!!', print_style='last', bottom_return=1)
403 |
404 | def __repr__(self):
405 | return '<DLExperiment: {}>'.format(self.model_name)
406 |
407 | def __str__(self):
408 | return '<DLExperiment: {}>'.format(self.model_name)
409 |
410 | def __len__(self):
411 | return len(self.dataset)
412 |
413 | @property
414 | def shape(self):
415 | return self.dataset.shape
416 |
417 | @property
418 | def train_frac(self):
419 | return self.dataset.train_frac
420 |
421 | @property
422 | def valid_frac(self):
423 | return self.dataset.valid_frac
424 |
425 | @property
426 | def test_frac(self):
427 | return self.dataset.test_frac
428 |
429 | @property
430 | def n_classes(self):
431 | return len(self.name_to_start_codes)
432 |
433 | @property
434 | def n_channels(self):
435 | return self.dataset.n_channels
436 |
437 | @property
438 | def n_samples(self):
439 | return self.dataset.n_samples
440 |
441 | def paths_manager(self):
442 | # results_dir is: .../results/hgdecode
443 | # dl_results_dir is: .../results/hgdecode/dl
444 | dl_results_dir = join(self.results_dir, 'dl')
445 |
446 | # model_results_dir is: .../results/hgdecode/dl/model_name
447 | model_results_dir = join(dl_results_dir, self.model_name)
448 |
449 | # fold_results_dir is .../results/dataset/dl/model/datetime/subj/fold
450 | fold_str = str(self.fold_idx + 1)
451 | if len(fold_str) == 1:
452 | fold_str = '0' + fold_str
453 | fold_str = 'fold' + fold_str
454 | fold_results_dir = join(self.subj_results_dir, fold_str)
455 |
456 | # setting on object self
457 | self.dl_results_dir = dl_results_dir
458 | self.model_results_dir = model_results_dir
459 | self.fold_results_dir = fold_results_dir
460 |
461 | # touching only the deepest directory also creates all the parent ones
462 | touch_dir(self.fold_results_dir)
463 |
464 | # statistics_dir is: .../results/hgdecode/dl/model/datetime/statistics
465 | statistics_dir = join(self.datetime_results_dir, 'statistics')
466 |
467 | # figures_dir is: .../results/hgdecode/dl/model/dt/stat/figures/subject
468 | figures_dir = join(statistics_dir, 'figures',
469 | my_formatter(self.subject_id, 'subj'))
470 |
471 | # tables_dir is: .../results/hgdecode/dl/model/datetime/stat/tables
472 | tables_dir = join(statistics_dir, 'tables')
473 |
474 | # setting on object self
475 | self.statistics_dir = statistics_dir
476 | self.figures_dir = figures_dir
477 | self.tables_dir = tables_dir
478 | touch_dir(figures_dir)
479 | touch_dir(tables_dir)
480 |
481 | # files in datetime_results_dir
482 | self.model_report_path = join(self.datetime_results_dir,
483 | 'model_report.txt')
484 | self.model_picture_path = join(self.datetime_results_dir,
485 | 'model_picture.png')
486 |
487 | # files in subj_results_dir
488 | self.log_path = join(self.subj_results_dir, 'log.bin')
489 |
490 | # files in fold_results_dir
491 | self.train_report_path = join(self.fold_results_dir,
492 | 'train_report.csv')
493 | self.fold_stats_path = join(self.fold_results_dir, 'fold_stats.pickle')
494 |
495 | # if the user wants to save the model on each epoch...
496 | if self.save_model_at_each_epoch:
497 | # ...creating a models directory and a per-epoch file name, else...
498 | self.h5_models_dir = join(self.fold_results_dir, 'h5_models')
499 | touch_dir(self.h5_models_dir)
500 | self.h5_model_path = join(self.h5_models_dir, 'net{epoch:02d}.h5')
501 | else:
502 | # ...pointing to the same results directory
503 | self.h5_model_path = join(self.fold_results_dir,
504 | 'net_best_val_loss.h5')
505 |
506 | def train(self):
507 | # saving a model picture
508 | # TODO: model_pic.png saving routine
509 |
510 | # saving a model report
511 | with open(self.model_report_path, 'w') as mr:
512 | self.model.summary(print_fn=lambda x: mr.write(x + '\n'))
513 |
514 | # pre-allocating callbacks list
515 | callbacks = []
516 |
517 | # saving a train report
518 | csv = CSVLogger(self.train_report_path)
519 | callbacks.append(csv)
520 |
521 | # saving model each epoch
522 | if self.save_model_at_each_epoch:
523 | mcp = ModelCheckpoint(self.h5_model_path)
524 | callbacks.append(mcp)
525 | # else:
526 | # mcp = ModelCheckpoint(self.h5_model_path,
527 | # monitor='val_loss',
528 | # save_best_only=True)
529 | # callbacks.append(mcp)
530 |
531 | # if early_stopping is True...
532 | if self.early_stopping is True:
533 | # putting epochs to a very large number
534 | epochs = 1000
535 |
536 | # creating early stopping callback
537 | esc = EarlyStopping(monitor=self.monitor,
538 | min_delta=self.min_delta,
539 | patience=self.patience,
540 | verbose=1)
541 | callbacks.append(esc)
542 | else:
543 | # getting user defined epochs value
544 | epochs = self.epochs
545 |
546 | # using fit_generator if a data generator is required
547 | if self.data_generator is True:
548 | training_generator = EEGDataGenerator(self.dataset.X_train,
549 | self.dataset.y_train,
550 | self.batch_size,
551 | self.n_classes,
552 | self.crop_sample_size,
553 | self.crop_step)
554 | validation_generator = EEGDataGenerator(self.dataset.X_valid,
555 | self.dataset.y_valid,
556 | self.batch_size,
557 | self.n_classes,
558 | self.crop_sample_size,
559 | self.crop_step)
560 |
561 | # training!
562 | print_manager(
563 | 'RUNNING TRAINING ON FOLD {}'.format(self.fold_idx + 1),
564 | 'double-dashed'
565 | )
566 | self.model.fit_generator(generator=training_generator,
567 | validation_data=validation_generator,
568 | use_multiprocessing=True,
569 | workers=self.workers,
570 | epochs=epochs,
571 | verbose=1,
572 | callbacks=callbacks)
573 | else:
574 | # creating crops
575 | self.dataset.make_crops(self.crop_sample_size, self.crop_step)
576 |
577 | # forcing the x examples to have 4 dimensions
578 | self.dataset.add_axis()
579 |
580 | # parsing y to categorical
581 | self.dataset.to_categorical()
582 |
583 | # TODO: MetricsTracker for Data Generation routine
584 | # creating a MetricsTracker instance
585 | if self.metrics_tracker is None:
586 | callbacks.append(
587 | MetricsTracker(
588 | dataset=self.dataset,
589 | epochs=self.epochs,
590 | n_classes=self.n_classes,
591 | batch_size=self.batch_size,
592 | h5_model_path=self.h5_model_path,
593 | fold_stats_path=self.fold_stats_path
594 | )
595 | )
596 | else:
597 | callbacks.append(self.metrics_tracker)
598 |
599 | # training!
600 | print_manager(
601 | 'RUNNING TRAINING ON FOLD {}'.format(self.fold_idx + 1),
602 | 'double-dashed'
603 | )
604 | self.model.fit(x=self.dataset.X_train,
605 | y=self.dataset.y_train,
606 | validation_data=(self.dataset.X_valid,
607 | self.dataset.y_valid),
608 | batch_size=self.batch_size,
609 | epochs=epochs,
610 | verbose=1,
611 | callbacks=callbacks,
612 | shuffle=self.shuffle)
613 | # TODO: if validation_frac is 0 or None, do not split into train and
614 | # validation sets (useful when tuning the epochs hyperparameter).
615 |
616 | def test(self):
617 | # TODO: evaluate_generator if data_generator is True
618 | # loading best net
619 | self.model.load_weights(self.h5_model_path)
620 |
621 | # computing loss and other metrics
622 | score = self.model.evaluate(
623 | self.dataset.X_test,
624 | self.dataset.y_test,
625 | verbose=1
626 | )
627 |
628 | print('Test loss:', score[0])
629 | print('Test acc:', score[1])
630 |
631 | # making predictions on X_test with final model and getting also
632 | # y_test from memory; parsing both back from categorical
633 | y_pred = self.model.predict(self.dataset.X_test).argmax(axis=1)
634 | if self.data_generator is True:
635 | y_test = self.dataset.y_test
636 | else:
637 | y_test = self.dataset.y_test.argmax(axis=1)
638 |
639 | # computing confusion matrix
640 | conf_mtx = confusion_matrix(y_true=y_test, y_pred=y_pred)
641 | print("Confusion matrix:\n", conf_mtx)
642 |
643 | def prepare_for_transfer_learning(self,
644 | cross_subj_dir_path,
645 | subject_id,
646 | train_anyway=False):
647 | # printing the start
648 | print_manager('PREPARING FOR TRANSFER LEARNING', 'double-dashed')
649 |
650 | # getting this subject cross-subject dir
651 | cross_subj_this_subj_dir_path = join(cross_subj_dir_path,
652 | 'subj_cross',
653 | my_formatter(subject_id, 'fold'))
654 |
655 | # loading
656 | self.model.load_weights(join(cross_subj_this_subj_dir_path,
657 | 'net_best_val_loss.h5'))
658 |
659 | if train_anyway is False:
660 | # pre-saving this net as best one
661 | self.model.save(self.h5_model_path)
662 |
663 | # creating metrics tracker instance
664 | self.metrics_tracker = MetricsTracker(
665 | dataset=self.dataset,
666 | epochs=self.epochs,
667 | n_classes=self.n_classes,
668 | batch_size=self.batch_size,
669 | h5_model_path=self.h5_model_path,
670 | fold_stats_path=self.fold_stats_path
671 | )
672 |
673 | # loading cross-subject info
674 | with open(join(cross_subj_this_subj_dir_path,
675 | 'fold_stats.pickle'), 'rb') as f:
676 | results = load(f)
677 |
678 | # forcing best net to be the 0 one
679 | self.metrics_tracker.best['loss'] = results['test']['loss']
680 | self.metrics_tracker.best['acc'] = results['test']['acc']
681 | self.metrics_tracker.best['idx'] = 0
682 |
683 | # printing the end
684 | print_manager('DONE!!', print_style='last', bottom_return=1)
685 |
686 | def freeze_layers(self, layers_to_freeze):
687 | print_manager('FREEZING LAYERS', 'double-dashed')
688 | if layers_to_freeze == 0:
689 | print('NOTHING TO FREEZE!!')
690 | else:
691 | print("I'm gonna gonna freeze {} layers.".format(layers_to_freeze))
692 |
693 | # freezing layers
694 | frozen = 0
695 | if layers_to_freeze > 0:
696 | idx = 0
697 | step = 1
698 | else:
699 | idx = -1
700 | step = -1
701 | layers_to_freeze = - layers_to_freeze
702 | while frozen < layers_to_freeze:
703 | layer = self.model.layers[idx]
704 | if layer.name[:4] == 'conv' or layer.name[:5] == 'dense':
705 | layer.trainable = False
706 | frozen += 1
707 | idx += step
708 |
709 | # creating optimizer instance (Adam is the only supported option)
710 | if self.optimizer == 'Adam':
711 | opt = optimizers.Adam(lr=self.learning_rate)
712 | else:
713 | opt = optimizers.Adam(lr=self.learning_rate)
714 |
715 | # compiling model
716 | self.model.compile(loss=self.loss,
717 | optimizer=opt,
718 | metrics=['accuracy'])
719 |
720 | # printing model information
721 | self.model.summary()
722 | print_manager('DONE!!', print_style='last', bottom_return=1)
723 |
--------------------------------------------------------------------------------
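
Note: to make the signed layers_to_freeze convention of DLExperiment.freeze_layers concrete: a positive count freezes from the input side, a negative count from the output side, and only layers whose names start with 'conv' or 'dense' are counted. A dependency-free sketch of that same walk (layer names are illustrative; the real method flips layer.trainable on the Keras layers and recompiles):

    def freeze_order(layer_names, layers_to_freeze):
        # assumes enough conv/dense layers exist, as the original loop does
        frozen, idx, step = [], 0, 1
        if layers_to_freeze < 0:
            idx, step, layers_to_freeze = -1, -1, -layers_to_freeze
        while len(frozen) < layers_to_freeze:
            name = layer_names[idx]
            if name.startswith(('conv', 'dense')):
                frozen.append(name)
            idx += step
        return frozen

    layers = ['conv_1', 'pool_1', 'conv_2', 'dense_1']
    print(freeze_order(layers, 2))    # ['conv_1', 'conv_2']
    print(freeze_order(layers, -2))   # ['dense_1', 'conv_2']
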
/hgdecode/fbcsprlda.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | import logging as log
4 | from copy import deepcopy
5 | from numpy import empty, mean, array
6 | from hgdecode.lda import lda_apply
7 | from hgdecode.lda import lda_train_scaled
8 | from hgdecode.signalproc import bandpass_mne
9 | from hgdecode.signalproc import select_trials
10 | from hgdecode.signalproc import calculate_csp
11 | from hgdecode.signalproc import select_classes
12 | from hgdecode.signalproc import apply_csp_var_log
13 | from hgdecode.signalproc import concatenate_channels
14 | from braindecode.datautil.iterators import get_balanced_batches
15 | from braindecode.datautil.trial_segment import \
16 | create_signal_target_from_raw_mne
17 |
18 |
19 | class BinaryFBCSP(object):
20 | """
21 | Pairwise (one-vs-one) FBCSP feature extraction with rLDA classification, run for every filterband, fold and class pair.
22 | """
23 |
24 | def __init__(self,
25 | cnt,
26 | clean_trial_mask,
27 | filterbands,
28 | filt_order,
29 | folds,
30 | class_pairs,
31 | epoch_ival_ms,
32 | n_filters,
33 | marker_def,
34 | name_to_stop_codes=None,
35 | average_trial_covariance=False):
36 | # cnt and signal parameters
37 | self.cnt = cnt
38 | self.clean_trial_mask = clean_trial_mask
39 | self.epoch_ival_ms = epoch_ival_ms
40 | self.marker_def = marker_def
41 | self.name_to_stop_codes = name_to_stop_codes
42 |
43 | # filter bank parameters
44 | self.filterbands = filterbands
45 | self.filt_order = filt_order
46 |
47 | # machine learning parameters
48 | self.folds = folds
49 | self.n_filters = n_filters
50 | self.class_pairs = class_pairs
51 | self.average_trial_covariance = average_trial_covariance
52 |
53 | # getting result shape
54 | n_filterbands = len(self.filterbands)
55 | n_folds = len(self.folds)
56 | n_class_pairs = len(self.class_pairs)
57 | result_shape = (n_filterbands, n_folds, n_class_pairs)
58 |
59 | # creating result related properties and pre-allocating them
60 | self.filters = empty(result_shape, dtype=object)
61 | self.patterns = empty(result_shape, dtype=object)
62 | self.variances = empty(result_shape, dtype=object)
63 | self.train_feature = empty(result_shape, dtype=object)
64 | self.test_feature = empty(result_shape, dtype=object)
65 | self.train_feature_full_fold = empty(result_shape, dtype=object)
66 | self.test_feature_full_fold = empty(result_shape, dtype=object)
67 | self.clf = empty(result_shape, dtype=object)
68 | self.train_accuracy = empty(result_shape, dtype=object)
69 | self.test_accuracy = empty(result_shape, dtype=object)
70 | self.train_labels_full_fold = empty(len(self.folds), dtype=object)
71 | self.test_labels_full_fold = empty(len(self.folds), dtype=object)
72 | self.train_labels = empty(
73 | (len(self.folds), len(self.class_pairs)),
74 | dtype=object)
75 | self.test_labels = empty(
76 | (len(self.folds), len(self.class_pairs)),
77 | dtype=object)
78 |
79 | def run(self):
80 | # %% CYCLING ON FILTERS IN FILTERBANK
81 | # %%
82 | # note: enumerate loops over an iterable while keeping an
83 | # automatic counter. Here bp_nr is the counter and filt_band
84 | # is the (low, high) frequency pair yielded at each step by
85 | # the FilterBank instance's iteration protocol (its
86 | # __getitem__ method).
87 | for bp_nr, filt_band in enumerate(self.filterbands):
88 | # printing filter information
89 | self.print_filter(bp_nr)
90 |
91 | # bandpassing all the cnt RawArray with the current filter
92 | bandpassed_cnt = bandpass_mne(
93 | self.cnt,
94 | filt_band[0],
95 | filt_band[1],
96 | filt_order=self.filt_order
97 | )
98 |
99 | # epoching: from cnt data to epoched data
100 | epo = create_signal_target_from_raw_mne(
101 | bandpassed_cnt,
102 | name_to_start_codes=self.marker_def,
103 | epoch_ival_ms=self.epoch_ival_ms,
104 | name_to_stop_codes=self.name_to_stop_codes
105 | )
106 |
107 | # cleaning epoched data with clean_trial_mask (finally)
108 | if len(self.folds) != 1:
109 | epo.X = epo.X[self.clean_trial_mask]
110 | epo.y = epo.y[self.clean_trial_mask]
111 |
112 | # %% CYCLING ON FOLDS
113 | # %%
114 | for fold_nr in range(len(self.folds)):
115 | # printing fold information
116 | self.print_fold_nr(fold_nr)
117 |
118 | # getting information on current fold
119 | train_test = self.folds[fold_nr]
120 |
121 | # getting train and test indexes
122 | train_ind = train_test['train']
123 | test_ind = train_test['test']
124 |
125 | # getting train data from train indexes
126 | epo_train = select_trials(epo, train_ind)
127 |
128 | # getting test data from test indexes
129 | epo_test = select_trials(epo, test_ind)
130 |
131 | # logging info on train
132 | log.info("#Train trials: {:4d}".format(len(epo_train.X)))
133 |
134 | # logging info on test
135 | log.info("#Test trials : {:4d}".format(len(epo_test.X)))
136 |
137 | # setting train labels of this fold
138 | self.train_labels_full_fold[fold_nr] = epo_train.y
139 |
140 | # setting test labels of this fold
141 | self.test_labels_full_fold[fold_nr] = epo_test.y
142 |
143 | # %% CYCLING ON ALL POSSIBLE CLASS PAIRS
144 | # %%
145 | for pair_nr in range(len(self.class_pairs)):
146 | # getting class pair from index (pair_nr)
147 | class_pair = self.class_pairs[pair_nr]
148 |
149 | # printing class pair information
150 | self.print_class_pair(class_pair)
151 |
152 | # getting train trials only for current two classes
153 | epo_train_pair = select_classes(epo_train, class_pair)
154 |
155 | # getting test trials only for current two classes
156 | epo_test_pair = select_classes(epo_test, class_pair)
157 |
158 | # saving train labels for this two classes
159 | self.train_labels[fold_nr][pair_nr] = epo_train_pair.y
160 |
161 | # saving test labels for this two classes
162 | self.test_labels[fold_nr][pair_nr] = epo_test_pair.y
163 |
164 | # %% COMPUTING CSP
165 | # %%
166 | filters, patterns, variances = calculate_csp(
167 | epo_train_pair,
168 | average_trial_covariance=self.average_trial_covariance
169 | )
170 |
171 | # %% FEATURE EXTRACTION
172 | # %%
173 | # choosing how many spatial filters to apply;
174 | # if no spatial filter number is specified...
175 | if self.n_filters is None:
176 | # ...taking all columns, else...
177 | columns = list(range(len(filters)))
178 | else:
179 | # ...take topmost and bottommost filters;
180 | # e.g. for n_filters=3 we are going to pick:
181 | # 0, 1, 2, -3, -2, -1
182 | columns = (list(range(0, self.n_filters)) +
183 | list(range(-self.n_filters, 0)))
184 |
185 | # feature extraction on train
186 | train_feature = apply_csp_var_log(
187 | epo_train_pair,
188 | filters,
189 | columns
190 | )
191 |
192 | # feature extraction on test
193 | test_feature = apply_csp_var_log(
194 | epo_test_pair,
195 | filters,
196 | columns
197 | )
198 |
199 | # %% COMPUTING LDA USING TRAIN FEATURES
200 | # %%
201 | # clf is a 1x2 tuple where:
202 | # * clf[0] is hyperplane parameters
203 | # * clf[1] is hyperplane bias
204 | # with clf, you can recreate the n-dimensional
205 | # hyperplane that splits class space, so you can
206 | # classify your fbcsp extracted features.
207 | clf = lda_train_scaled(train_feature, shrink=True)
208 |
209 | # %% APPLYING LDA ON TRAIN
210 | # %%
211 | # applying LDA
212 | train_out = lda_apply(train_feature, clf)
213 |
214 | # getting true/false labels instead of class labels
215 | # for example, if you have:
216 | # train_feature.y --> [1, 3, 3, 1]
217 | # class_pair --> [1, 3]
218 | # so you will have:
219 | # true_0_1_labels_train = [False, True, True, False]
220 | true_0_1_labels_train = train_feature.y == class_pair[1]
221 |
222 | # predicted label is True if the LDA output is >= 0, else False
223 | predicted_train = train_out >= 0
224 |
225 | # computing accuracy
226 | # the mean of a boolean array is the number of True
227 | # elements divided by the total number of elements,
228 | # i.e. the accuracy
229 | train_accuracy = mean(
230 | true_0_1_labels_train == predicted_train
231 | )
232 |
233 | # %% APPLYING LDA ON TEST
234 | # %%
235 | # same procedure
236 | test_out = lda_apply(test_feature, clf)
237 | true_0_1_labels_test = test_feature.y == class_pair[1]
238 | predicted_test = test_out >= 0
239 | test_accuracy = mean(
240 | true_0_1_labels_test == predicted_test
241 | )
242 |
243 | # %% FEATURE COMPUTATION FOR FULL FOLD
244 | # %% (FOR LATER MULTICLASS)
245 | # here we use csp computed only for this pair of classes
246 | # to compute feature for all the current fold
247 | # train here
248 | train_feature_full_fold = apply_csp_var_log(
249 | epo_train,
250 | filters,
251 | columns
252 | )
253 |
254 | # test here
255 | test_feature_full_fold = apply_csp_var_log(
256 | epo_test,
257 | filters,
258 | columns
259 | )
260 |
261 | # %% STORE RESULTS
262 | # %%
263 | # only store used patterns filters variances
264 | # to save memory space on disk
265 | self.store_results(
266 | bp_nr,
267 | fold_nr,
268 | pair_nr,
269 | filters[:, columns],
270 | patterns[:, columns],
271 | variances[columns],
272 | train_feature,
273 | test_feature,
274 | train_feature_full_fold,
275 | test_feature_full_fold,
276 | clf,
277 | train_accuracy,
278 | test_accuracy
279 | )
280 |
281 | # printing the end of this super-nested cycle
282 | self.print_results(bp_nr, fold_nr, pair_nr)
283 |
284 | # printing a blank line to divide filters
285 | print()
286 |
287 | def store_results(self,
288 | bp_nr,
289 | fold_nr,
290 | pair_nr,
291 | filters,
292 | patterns,
293 | variances,
294 | train_feature,
295 | test_feature,
296 | train_feature_full_fold,
297 | test_feature_full_fold,
298 | clf,
299 | train_accuracy,
300 | test_accuracy):
301 | """ Store all supplied arguments to this objects dict, at the correct
302 | indices for filterband / fold / class_pair."""
303 | local_vars = locals()
304 | del local_vars['self']
305 | del local_vars['bp_nr']
306 | del local_vars['fold_nr']
307 | del local_vars['pair_nr']
308 | for var in local_vars:
309 | self.__dict__[var][bp_nr, fold_nr, pair_nr] = local_vars[var]
310 |
311 | def print_filter(self, bp_nr):
312 | # distinguish filter blocks by empty line
313 | log.info(
314 | "Filter {:d}/{:d}, {:4.2f} to {:4.2f} Hz".format(
315 | bp_nr + 1,
316 | len(self.filterbands),
317 | *self.filterbands[bp_nr])
318 | )
319 |
320 | @staticmethod
321 | def print_fold_nr(fold_nr):
322 | log.info("Fold Nr: {:d}".format(fold_nr + 1))
323 |
324 | @staticmethod
325 | def print_class_pair(class_pair):
326 | class_pair_plus_one = (array(class_pair) + 1).tolist()
327 | log.info("Class {:d} vs {:d}".format(*class_pair_plus_one))
328 |
329 | def print_results(self, bp_nr, fold_nr, pair_nr):
330 | log.info("Train: {:4.2f}%".format(
331 | self.train_accuracy[bp_nr, fold_nr, pair_nr] * 100))
332 | log.info("Test: {:4.2f}%".format(
333 | self.test_accuracy[bp_nr, fold_nr, pair_nr] * 100))
334 |
335 |
336 | class FBCSP(object):
337 | """
338 | Combines BinaryFBCSP features across filterbands, optionally selecting the best filterbands and features, then trains rLDA classifiers.
339 | """
340 |
341 | def __init__(self,
342 | binary_csp,
343 | n_features=None,
344 | n_filterbands=None,
345 | forward_steps=2,
346 | backward_steps=1,
347 | stop_when_no_improvement=False):
348 | # copying inputs
349 | self.binary_csp = binary_csp
350 | self.n_features = n_features
351 | self.n_filterbands = n_filterbands
352 | self.forward_steps = forward_steps
353 | self.backward_steps = backward_steps
354 | self.stop_when_no_improvement = stop_when_no_improvement
355 |
356 | # pre-allocating other properties
357 | self.train_feature = None
358 | self.train_feature_full_fold = None
359 | self.test_feature = None
360 | self.test_feature_full_fold = None
361 | self.selected_filter_inds = None
362 | self.selected_filters_per_filterband = None
363 | self.selected_features = None
364 | self.clf = None
365 | self.train_accuracy = None
366 | self.test_accuracy = None
367 | self.train_pred_full_fold = None
368 | self.test_pred_full_fold = None
369 |
370 | def run(self):
371 | self.select_filterbands()
372 | if self.n_features is not None:
373 | log.info("Run feature selection...")
374 | self.collect_best_features()
375 | log.info("Done.")
376 | else:
377 | self.collect_features()
378 | self.train_classifiers()
379 | self.predict_outputs()
380 |
381 | def select_filterbands(self):
382 | n_all_filterbands = len(self.binary_csp.filterbands)
383 | if self.n_filterbands is None:
384 | self.selected_filter_inds = list(range(n_all_filterbands))
385 | else:
386 | # Select the filterbands with the highest mean accuracy on the
387 | # training sets
388 | mean_accs = np.mean(self.binary_csp.train_accuracy, axis=(1, 2))
389 | best_filters = np.argsort(mean_accs)[::-1][:self.n_filterbands]
390 | self.selected_filter_inds = best_filters
391 |
392 | def collect_features(self):
393 | n_folds = len(self.binary_csp.folds)
394 | n_class_pairs = len(self.binary_csp.class_pairs)
395 | result_shape = (n_folds, n_class_pairs)
396 | self.train_feature = np.empty(result_shape, dtype=object)
397 | self.train_feature_full_fold = np.empty(result_shape, dtype=object)
398 | self.test_feature = np.empty(result_shape, dtype=object)
399 | self.test_feature_full_fold = np.empty(result_shape, dtype=object)
400 |
401 | bcsp = self.binary_csp # just to make code shorter
402 | filter_inds = self.selected_filter_inds
403 | for fold_i in range(n_folds):
404 | for class_i in range(n_class_pairs):
405 | self.train_feature[fold_i, class_i] = concatenate_channels(
406 | bcsp.train_feature[filter_inds, fold_i, class_i])
407 | self.train_feature_full_fold[fold_i, class_i] = (
408 | concatenate_channels(
409 | bcsp.train_feature_full_fold[
410 | filter_inds, fold_i, class_i]))
411 | self.test_feature[fold_i, class_i] = concatenate_channels(
412 | bcsp.test_feature[filter_inds, fold_i, class_i]
413 | )
414 | self.test_feature_full_fold[fold_i, class_i] = (
415 | concatenate_channels(
416 | bcsp.test_feature_full_fold[
417 | filter_inds, fold_i, class_i]
418 | ))
419 |
420 | def collect_best_features(self):
421 | """ Selects features filterwise per filterband, starting with no
422 | features, then selecting the best filterpair from the bestfilterband
423 | (measured on internal train/test split)"""
424 | bincsp = self.binary_csp # just to make code shorter
425 |
426 | # getting dimension for feature arrays
427 | n_folds = len(self.binary_csp.folds)
428 | n_class_pairs = len(self.binary_csp.class_pairs)
429 | result_shape = (n_folds, n_class_pairs)
430 |
431 | # initializing feature arrays for these classes
432 | self.train_feature = np.empty(result_shape, dtype=object)
433 | self.train_feature_full_fold = np.empty(result_shape, dtype=object)
434 | self.test_feature = np.empty(result_shape, dtype=object)
435 | self.test_feature_full_fold = np.empty(result_shape, dtype=object)
436 | self.selected_filters_per_filterband = np.empty(result_shape,
437 | dtype=object)
438 | # outer cycle on folds
439 | for fold_i in range(n_folds):
440 | # inner cycle on pairs
441 | for class_pair_i in range(n_class_pairs):
442 | # copying bincsp features locally (to avoid modifying them in place)
443 | bin_csp_train_features = deepcopy(
444 | bincsp.train_feature[
445 | self.selected_filter_inds, fold_i, class_pair_i
446 | ]
447 | )
448 | bin_csp_train_features_full_fold = deepcopy(
449 | bincsp.train_feature_full_fold[
450 | self.selected_filter_inds,
451 | fold_i, class_pair_i
452 | ]
453 | )
454 | bin_csp_test_features = deepcopy(
455 | bincsp.test_feature[
456 | self.selected_filter_inds,
457 | fold_i,
458 | class_pair_i
459 | ]
460 | )
461 | bin_csp_test_features_full_fold = deepcopy(
462 | bincsp.test_feature_full_fold[
463 | self.selected_filter_inds, fold_i, class_pair_i
464 | ]
465 | )
466 |
467 | # selecting best filters
468 | selected_filters_per_filt = \
469 | self.select_best_filters_best_filterbands(
470 | bin_csp_train_features,
471 | max_features=self.n_features,
472 | forward_steps=self.forward_steps,
473 | backward_steps=self.backward_steps,
474 | stop_when_no_improvement=self.stop_when_no_improvement
475 | )
476 |
477 | # collecting train features
478 | self.train_feature[fold_i, class_pair_i] = \
479 | self.collect_features_for_filter_selection(
480 | bin_csp_train_features,
481 | selected_filters_per_filt
482 | )
483 |
484 | # collecting train features full fold
485 | self.train_feature_full_fold[fold_i, class_pair_i] = \
486 | self.collect_features_for_filter_selection(
487 | bin_csp_train_features_full_fold,
488 | selected_filters_per_filt
489 | )
490 |
491 | # collecting test features
492 | self.test_feature[fold_i, class_pair_i] = \
493 | self.collect_features_for_filter_selection(
494 | bin_csp_test_features,
495 | selected_filters_per_filt
496 | )
497 |
498 | # collecting test features full fold
499 | self.test_feature_full_fold[fold_i, class_pair_i] = \
500 | self.collect_features_for_filter_selection(
501 | bin_csp_test_features_full_fold,
502 | selected_filters_per_filt
503 | )
504 |
505 | # saving also the filters selected for this fold and pair
506 | self.selected_filters_per_filterband[fold_i, class_pair_i] = \
507 | selected_filters_per_filt
508 |
509 | @staticmethod
510 | def select_best_filters_best_filterbands(features,
511 | max_features,
512 | forward_steps,
513 | backward_steps,
514 | stop_when_no_improvement):
515 | n_filterbands = len(features)
516 |         n_filters_per_fb = features[0].X.shape[1] // 2
517 | selected_filters_per_band = [0] * n_filterbands
518 | best_selected_filters_per_filterband = None
519 | last_best_accuracy = -1
520 |
521 | # Run until no improvement or max features reached
522 | selection_finished = False
523 | while not selection_finished:
524 | for _ in range(forward_steps):
525 | best_accuracy = -1
526 |
527 |                 # try adding one more filter pair from each filterband in turn
528 | for filt_i in range(n_filterbands):
529 | this_filt_per_fb = deepcopy(selected_filters_per_band)
530 | if this_filt_per_fb[filt_i] == n_filters_per_fb:
531 | continue
532 | this_filt_per_fb[filt_i] = this_filt_per_fb[filt_i] + 1
533 | all_features = \
534 | FBCSP.collect_features_for_filter_selection(
535 | features,
536 | this_filt_per_fb
537 | )
538 |
539 |                     # run 5-fold cross-validation with LDA
540 | test_accuracy = FBCSP.cross_validate_lda(
541 | all_features
542 | )
543 |
544 | if test_accuracy > best_accuracy:
545 | best_accuracy = test_accuracy
546 | best_selected_filters_per_filterband = this_filt_per_fb
547 |
548 | selected_filters_per_band = \
549 | best_selected_filters_per_filterband
550 |
551 | for _ in range(backward_steps):
552 | best_accuracy = -1
553 |                 # try removing one filter pair from each filterband in turn
554 | for filt_i in range(n_filterbands):
555 | this_filt_per_fb = deepcopy(selected_filters_per_band)
556 | if this_filt_per_fb[filt_i] == 0:
557 | continue
558 | this_filt_per_fb[filt_i] = this_filt_per_fb[filt_i] - 1
559 | all_features = \
560 | FBCSP.collect_features_for_filter_selection(
561 | features,
562 | this_filt_per_fb
563 | )
564 |                     # run 5-fold cross-validation with LDA
565 | test_accuracy = FBCSP.cross_validate_lda(
566 | all_features
567 | )
568 | if test_accuracy > best_accuracy:
569 | best_accuracy = test_accuracy
570 | best_selected_filters_per_filterband = this_filt_per_fb
571 | selected_filters_per_band = \
572 | best_selected_filters_per_filterband
573 |
574 | selection_finished = 2 * np.sum(
575 | selected_filters_per_band) >= max_features
576 | if stop_when_no_improvement:
577 | # there was no improvement if accuracy did not increase...
578 | selection_finished = (selection_finished
579 | or best_accuracy <= last_best_accuracy)
580 | last_best_accuracy = best_accuracy
581 | return selected_filters_per_band
582 |
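    # A hedged invocation sketch for the greedy search above (the step
    # counts and feature budget here are illustrative, not values taken
    # from this project's experiments):
    #
    #   selected = FBCSP.select_best_filters_best_filterbands(
    #       features,                 # one SignalAndTarget per filterband
    #       max_features=8,           # stop once 8 features are selected
    #       forward_steps=2,
    #       backward_steps=1,
    #       stop_when_no_improvement=True)
    #
    # Each forward step adds the filter pair that most improves the 5-fold
    # LDA accuracy; each backward step removes the least useful one.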
583 | @staticmethod
584 | def collect_features_for_filter_selection(
585 | features,
586 | filters_for_filterband
587 | ):
588 | n_filters_per_fb = features[0].X.shape[1] // 2
589 | n_filterbands = len(features)
590 | # start with filters of first filterband...
591 | # then add others all together
592 | first_features = deepcopy(features[0])
593 | first_n_filters = filters_for_filterband[0]
594 | if first_n_filters == 0:
595 | first_features.X = first_features.X[:, 0:0]
596 | else:
597 | first_features.X = \
598 | first_features.X[
599 | :,
600 | list(range(first_n_filters)) +
601 | list(range(-first_n_filters, 0))
602 | ]
603 |
604 | all_features = first_features
605 | for i in range(1, n_filterbands):
606 | this_n_filters = min(n_filters_per_fb, filters_for_filterband[i])
607 | if this_n_filters > 0:
608 | next_features = deepcopy(features[i])
609 |                 # keep the first and last this_n_filters columns, i.e.
610 |                 # the largest- and smallest-eigenvalue CSP components
611 |                 # of this filterband (mirrors the slicing above)
612 |                 next_features.X = \
613 |                     next_features.X[
614 |                         :,
615 |                         list(range(this_n_filters)) +
616 |                         list(range(-this_n_filters, 0))
617 |                     ]
618 | all_features = concatenate_channels(
619 | (all_features, next_features))
620 | return all_features
621 |
622 | @staticmethod
623 | def cross_validate_lda(features):
624 | n_trials = features.X.shape[0]
625 | folds = get_balanced_batches(n_trials, rng=None, shuffle=False,
626 | n_batches=5)
627 |         # turn folds into train/test splits: each fold is the test part
628 | folds = [(np.setdiff1d(np.arange(n_trials), fold),
629 | fold) for fold in folds]
630 | test_accuracies = []
631 | for train_inds, test_inds in folds:
632 | train_features = select_trials(features, train_inds)
633 | test_features = select_trials(features, test_inds)
634 | clf = lda_train_scaled(train_features, shrink=True)
635 | test_out = lda_apply(test_features, clf)
636 |
637 | higher_class = np.max(test_features.y)
638 | true_0_1_labels_test = test_features.y == higher_class
639 |
640 | predicted_test = test_out >= 0
641 | test_accuracy = np.mean(true_0_1_labels_test == predicted_test)
642 | test_accuracies.append(test_accuracy)
643 | return np.mean(test_accuracies)
644 |
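    # Why thresholding at 0 works: lda_train_scaled (see hgdecode/lda.py)
    # orients w along mu2 - mu1 and sets b = -w . mean(mu1, mu2), so the
    # projection of a trial is positive on the side of the higher of the
    # two (sorted) class labels.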
645 | def train_classifiers(self):
646 | n_folds = len(self.binary_csp.folds)
647 | n_class_pairs = len(self.binary_csp.class_pairs)
648 | self.clf = np.empty((n_folds, n_class_pairs),
649 | dtype=object)
650 | for fold_i in range(n_folds):
651 | for class_i in range(n_class_pairs):
652 | train_feature = self.train_feature[fold_i, class_i]
653 | clf = lda_train_scaled(train_feature, shrink=True)
654 | self.clf[fold_i, class_i] = clf
655 |
656 | def predict_outputs(self):
657 | n_folds = len(self.binary_csp.folds)
658 | n_class_pairs = len(self.binary_csp.class_pairs)
659 | result_shape = (n_folds, n_class_pairs)
660 | self.train_accuracy = np.empty(result_shape, dtype=float)
661 | self.test_accuracy = np.empty(result_shape, dtype=float)
662 | self.train_pred_full_fold = np.empty(result_shape, dtype=object)
663 | self.test_pred_full_fold = np.empty(result_shape, dtype=object)
664 | for fold_i in range(n_folds):
665 | log.info("Fold Nr: {:d}".format(fold_i + 1))
666 | for class_i, class_pair in enumerate(self.binary_csp.class_pairs):
667 | clf = self.clf[fold_i, class_i]
668 | class_pair_plus_one = (np.array(class_pair) + 1).tolist()
669 | log.info("Class {:d} vs {:d}".format(*class_pair_plus_one))
670 | train_feature = self.train_feature[fold_i, class_i]
671 | train_out = lda_apply(train_feature, clf)
672 | true_0_1_labels_train = train_feature.y == class_pair[1]
673 | predicted_train = train_out >= 0
674 |                 # float( ... ) casts the numpy scalar to a plain float
675 | train_accuracy = float(np.mean(true_0_1_labels_train
676 | == predicted_train))
677 | self.train_accuracy[fold_i, class_i] = train_accuracy
678 |
679 | test_feature = self.test_feature[fold_i, class_i]
680 | test_out = lda_apply(test_feature, clf)
681 | true_0_1_labels_test = test_feature.y == class_pair[1]
682 | predicted_test = test_out >= 0
683 | test_accuracy = float(np.mean(true_0_1_labels_test
684 | == predicted_test))
685 |
686 | self.test_accuracy[fold_i, class_i] = test_accuracy
687 |
688 | train_feature_full_fold = self.train_feature_full_fold[fold_i,
689 | class_i]
690 | train_out_full_fold = lda_apply(train_feature_full_fold, clf)
691 | self.train_pred_full_fold[
692 | fold_i, class_i] = train_out_full_fold
693 | test_feature_full_fold = self.test_feature_full_fold[fold_i,
694 | class_i]
695 | test_out_full_fold = lda_apply(test_feature_full_fold, clf)
696 | self.test_pred_full_fold[fold_i, class_i] = test_out_full_fold
697 |
698 | log.info("Train: {:4.2f}%".format(train_accuracy * 100))
699 | log.info("Test: {:4.2f}%".format(test_accuracy * 100))
700 |
701 |
702 | class MultiClassWeightedVoting(object):
703 | """
704 |     Multiclass decoding by weighted voting over the pairwise LDA outputs.
705 | """
706 |
707 | def __init__(self,
708 | train_labels,
709 | test_labels,
710 | train_preds,
711 | test_preds,
712 | class_pairs):
713 | # copying input parameters
714 | self.train_labels = train_labels
715 | self.test_labels = test_labels
716 | self.train_preds = train_preds
717 | self.test_preds = test_preds
718 | self.class_pairs = class_pairs
719 |
720 | # pre-allocating other useful properties
721 | self.train_class_sums = None
722 | self.test_class_sums = None
723 | self.train_predicted_labels = None
724 | self.test_predicted_labels = None
725 | self.train_accuracy = None
726 | self.test_accuracy = None
727 |
728 | def run(self):
729 | # determine number of classes by number of unique classes
730 | # appearing in class pairs
731 | n_classes = len(np.unique(list(itertools.chain(*self.class_pairs))))
732 | n_folds = len(self.train_labels)
733 | self.train_class_sums = np.empty(n_folds, dtype=object)
734 | self.test_class_sums = np.empty(n_folds, dtype=object)
735 | self.train_predicted_labels = np.empty(n_folds, dtype=object)
736 | self.test_predicted_labels = np.empty(n_folds, dtype=object)
737 | self.train_accuracy = np.ones(n_folds) * np.nan
738 | self.test_accuracy = np.ones(n_folds) * np.nan
739 | for fold_nr in range(n_folds):
740 | log.info("Fold Nr: {:d}".format(fold_nr + 1))
741 | train_labels = self.train_labels[fold_nr]
742 | train_preds = self.train_preds[fold_nr]
743 | train_class_sums = np.zeros((len(train_labels), n_classes))
744 |
745 | test_labels = self.test_labels[fold_nr]
746 | test_preds = self.test_preds[fold_nr]
747 | test_class_sums = np.zeros((len(test_labels), n_classes))
748 | for pair_i, class_pair in enumerate(self.class_pairs):
749 | this_train_preds = train_preds[pair_i]
750 | assert len(this_train_preds) == len(train_labels)
751 | train_class_sums[:, class_pair[0]] -= this_train_preds
752 | train_class_sums[:, class_pair[1]] += this_train_preds
753 | this_test_preds = test_preds[pair_i]
754 | assert len(this_test_preds) == len(test_labels)
755 | test_class_sums[:, class_pair[0]] -= this_test_preds
756 | test_class_sums[:, class_pair[1]] += this_test_preds
757 |
758 | self.train_class_sums[fold_nr] = train_class_sums
759 | self.test_class_sums[fold_nr] = test_class_sums
760 | train_predicted_labels = np.argmax(train_class_sums, axis=1)
761 | test_predicted_labels = np.argmax(test_class_sums, axis=1)
762 | self.train_predicted_labels[fold_nr] = train_predicted_labels
763 | self.test_predicted_labels[fold_nr] = test_predicted_labels
764 | train_accuracy = (np.sum(train_predicted_labels == train_labels) /
765 | float(len(train_labels)))
766 | self.train_accuracy[fold_nr] = train_accuracy
767 | test_accuracy = (np.sum(test_predicted_labels == test_labels) /
768 | float(len(test_labels)))
769 | self.test_accuracy[fold_nr] = test_accuracy
770 | log.info("Train: {:4.2f}%".format(train_accuracy * 100))
771 | log.info("Test: {:4.2f}%".format(test_accuracy * 100))
772 |
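    # A tiny numeric sketch of the voting rule above (the LDA outputs are
    # made up): with class_pairs [(0, 1), (0, 2), (1, 2)] and outputs
    # +0.7, -0.3 and +0.9 for one trial, the signed class sums are
    #   class 0: -(+0.7) - (-0.3)          = -0.4
    #   class 1: +(+0.7)          - (+0.9) = -0.2
    #   class 2:          +(-0.3) + (+0.9) = +0.6
    # so the argmax over the class sums predicts class 2 for this trial.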
--------------------------------------------------------------------------------
/hgdecode/lda.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.covariance import LedoitWolf
3 |
4 |
5 | def lda_train_scaled(fv, shrink=False):
6 | """Train the LDA classifier.
7 |
8 | Parameters
9 | ----------
10 | fv : ``Data`` object
11 | the feature vector must have 2 dimensional data, the first
12 |         dimension being the trial axis. There must be exactly two unique
13 |         class labels, otherwise a ``ValueError`` will be raised.
14 |     shrink : Boolean, optional
15 |         use shrinkage (this implementation currently asserts it is True)
16 |
17 | Returns
18 | -------
19 | w : 1d array
20 | b : float
21 |
22 | Raises
23 | ------
24 |     ValueError : if there are not exactly two unique class labels
25 |
26 | Examples
27 | --------
28 |
29 |     >> clf = lda_train_scaled(fv_train, shrink=True)
30 | >> out = lda_apply(fv_test, clf)
31 |
32 | See Also
33 | --------
34 | lda_apply
35 |
36 | """
37 | assert shrink is True
38 | assert fv.X.ndim == 2
39 | x = fv.X
40 | y = fv.y
41 | if len(np.unique(y)) != 2:
42 | raise ValueError(
43 | 'Should only have two unique class labels, instead got'
44 | ': {labels}'.format(labels=np.unique(y)))
45 | # Use sorted labels
46 | labels = np.sort(np.unique(y))
47 | mu1 = np.mean(x[y == labels[0]], axis=0)
48 | mu2 = np.mean(x[y == labels[1]], axis=0)
49 | # x' = x - m
50 | m = np.empty(x.shape)
51 | m[y == labels[0]] = mu1
52 | m[y == labels[1]] = mu2
53 | x2 = x - m
54 | # w = cov(x)^-1(mu2 - mu1)
55 | if shrink:
56 | estimator = LedoitWolf()
57 | covm = estimator.fit(x2).covariance_
58 | else:
59 | covm = np.cov(x2.T)
60 | w = np.dot(np.linalg.pinv(covm), (mu2 - mu1))
61 |
62 | # From MATLAB bbci toolbox:
63 | # https://github.com/bbci/bbci_public/blob/
64 | # fe6caeb549fdc864a5accf76ce71dd2a926ff12b/classification/
65 | # train_RLDAshrink.m#L133-L134
66 | # C.w= C.w/(C.w'*diff(C_mean, 1, 2))*2;
67 | # C.b= -C.w' * mean(C_mean,2);
68 | w = (w / np.dot(w.T, (mu2 - mu1))) * 2
69 | b = np.dot(-w.T, np.mean((mu1, mu2), axis=0))
70 | assert not np.any(np.isnan(w))
71 | assert not np.isnan(b)
72 | return w, b
73 |
74 |
75 | def lda_apply(fv, clf):
76 | """Apply feature vector to LDA classifier.
77 |
78 | Parameters
79 | ----------
80 | fv : ``Data`` object
81 | the feature vector must have a 2 dimensional data, the first
82 | dimension being the class axis.
83 | clf : (1d array, float)
84 |
85 | Returns
86 | -------
87 |
88 | out : 1d array
89 | The projection of the data on the hyperplane.
90 |
91 | Examples
92 | --------
93 |
94 |     >> clf = lda_train_scaled(fv_train, shrink=True)
95 | >> out = lda_apply(fv_test, clf)
96 |
97 |
98 | See Also
99 | --------
100 | lda_train
101 |
102 | """
103 | x = fv.X
104 | w, b = clf
105 | return np.dot(x, w) + b
106 |
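# --- Illustrative usage (a minimal sketch added for this document, not
# part of the original module). SignalAndTarget from braindecode is
# assumed as the feature container, since that is what the rest of
# hgdecode passes to these functions.
if __name__ == '__main__':
    from braindecode.datautil.signal_target import SignalAndTarget

    rng = np.random.RandomState(0)
    # two gaussian clouds, 20 trials each, 4 features per trial
    X = np.concatenate([rng.randn(20, 4) - 1, rng.randn(20, 4) + 1])
    y = np.concatenate([np.zeros(20), np.ones(20)])
    fv = SignalAndTarget(X, y)

    w, b = lda_train_scaled(fv, shrink=True)
    out = lda_apply(fv, (w, b))
    # outputs >= 0 predict the higher of the two (sorted) labels
    print('accuracy:', np.mean((out >= 0) == (y == 1)))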
--------------------------------------------------------------------------------
/hgdecode/loaders.py:
--------------------------------------------------------------------------------
1 | import logging as log
2 | from copy import deepcopy
3 | from numpy import max
4 | from numpy import abs
5 | from numpy import sum
6 | from numpy import std
7 | from numpy import mean
8 | from numpy import array
9 | from numpy import floor
10 | from numpy import repeat
11 | from numpy import arange
12 | from numpy import setdiff1d
13 | from numpy import concatenate
14 | from numpy import count_nonzero
15 | from numpy.random import RandomState
16 | from os.path import join
17 | from mne.io.array.array import RawArray
18 | from hgdecode.utils import print_manager
19 | from hgdecode.classes import CrossValidation
20 | from sklearn.model_selection import StratifiedKFold
21 | from braindecode.datasets.bbci import BBCIDataset
22 | from braindecode.mne_ext.signalproc import mne_apply
23 | from braindecode.mne_ext.signalproc import resample_cnt
24 | from braindecode.mne_ext.signalproc import concatenate_raws_with_events
25 | from braindecode.datautil.signalproc import bandpass_cnt
26 | from braindecode.datautil.signalproc import exponential_running_standardize
27 | from braindecode.datautil.signal_target import SignalAndTarget
28 | from braindecode.datautil.trial_segment import \
29 | create_signal_target_from_raw_mne
30 |
31 |
32 | # TODO: re-implement all these functions as a single class
33 |
34 |
35 | def get_data_files_paths(data_dir, subject_id=1, train_test_split=True):
36 | # compute file name (for both train and test path)
37 | file_name = '{:d}.mat'.format(subject_id)
38 |
39 | # compute file paths
40 | if train_test_split:
41 | train_file_path = join(data_dir, 'train', file_name)
42 | test_file_path = join(data_dir, 'test', file_name)
43 | file_path = [train_file_path, test_file_path]
44 | else:
45 | file_path = [join(data_dir, 'train', file_name)]
46 |
47 | # return paths
48 | return file_path
49 |
50 |
51 | def load_cnt(file_path, channel_names, clean_on_all_channels=True):
52 | # if we have to run the cleaning procedure on all channels, putting
53 | # load_sensor_names to None will assure us the BBCIDataset class will
54 | # load all possible sensors
55 | if clean_on_all_channels is True:
56 | channel_names = None
57 |
58 | # create the loader object for BBCI standard
59 | loader = BBCIDataset(file_path, load_sensor_names=channel_names)
60 |
61 | # load data
62 | return loader.load()
63 |
64 |
65 | def get_clean_trial_mask(cnt, name_to_start_codes, clean_ival_ms=(0, 4000)):
66 | """
67 |     Scan the trials in the continuous data and create a mask marking
68 |     only the valid ones; at the end of the loading routine, after all
69 |     the data pre-processing, this mask can be used to cut away the
70 |     invalid trials.
71 | """
72 | # split cnt into trials data for cleaning
73 | set_for_cleaning = create_signal_target_from_raw_mne(
74 | cnt,
75 | name_to_start_codes,
76 | clean_ival_ms
77 | )
78 |
79 |     # compute the clean_trial_mask: keep only trials whose absolute
80 |     # amplitude stays below 800 microvolt on every channel and sample
81 | clean_trial_mask = max(abs(set_for_cleaning.X), axis=(1, 2)) < 800
82 |
83 | # logging clean trials information
84 | log.info(
85 | 'Clean trials: {:3d} of {:3d} ({:5.1f}%)'.format(
86 | sum(clean_trial_mask),
87 | len(set_for_cleaning.X),
88 | mean(clean_trial_mask) * 100)
89 | )
90 |
91 | # return the clean_trial_mask
92 | return clean_trial_mask
93 |
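# The masking rule above, spelled out: set_for_cleaning.X has shape
# (n_trials, n_channels, n_samples), so max(abs(X), axis=(1, 2)) yields
# one peak absolute amplitude per trial, and a trial survives only if
# that peak stays below 800 microvolt.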
94 |
95 | def pick_right_channels(cnt, channel_names):
96 | # return the same cnt but with only right channels
97 | return cnt.pick_channels(channel_names)
98 |
99 |
100 | def standardize_cnt(cnt, standardize_mode=0):
101 | # computing frequencies
102 | sampling_freq = cnt.info['sfreq']
103 | init_freq = 0.1
104 | stop_freq = sampling_freq / 2 - 0.1
105 | filt_order = 3
106 | axis = 0
107 | filtfilt = False
108 |
109 | # filtering DC and frequencies higher than the nyquist one
110 | cnt = mne_apply(
111 | lambda x:
112 | bandpass_cnt(
113 | data=x,
114 | low_cut_hz=init_freq,
115 | high_cut_hz=stop_freq,
116 | fs=sampling_freq,
117 | filt_order=filt_order,
118 | axis=axis,
119 | filtfilt=filtfilt
120 | ),
121 | cnt
122 | )
123 |
124 | # removing mean and normalizing in 3 different ways
125 | if standardize_mode == 0:
126 | # x - mean
127 | cnt = mne_apply(
128 | lambda x:
129 | x - mean(x, axis=0, keepdims=True),
130 | cnt
131 | )
132 | elif standardize_mode == 1:
133 | # (x - mean) / std
134 | cnt = mne_apply(
135 | lambda x:
136 | (x - mean(x, axis=0, keepdims=True)) /
137 | std(x, axis=0, keepdims=True),
138 | cnt
139 | )
140 | elif standardize_mode == 2:
141 | # parsing to milli volt for numerical stability of next operations
142 | cnt = mne_apply(lambda a: a * 1e6, cnt)
143 |
144 | # applying exponential_running_standardize (Schirrmeister)
145 | cnt = mne_apply(
146 | lambda x:
147 | exponential_running_standardize(
148 | x.T,
149 | factor_new=1e-3,
150 | init_block_size=1000,
151 | eps=1e-4
152 | ).T,
153 | cnt
154 | )
155 | return cnt
156 |
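# In short, after the initial 0.1Hz-to-Nyquist band-pass, the three
# standardize modes above are:
#   mode 0: x - mean(x)                (mean removal along axis 0)
#   mode 1: (x - mean(x)) / std(x)     (z-scoring along axis 0)
#   mode 2: scale to microvolt, then exponential running standardization
#           as in Schirrmeister et al. (2017)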
157 |
158 | def load_and_preprocess_data(data_dir,
159 | name_to_start_codes,
160 | channel_names,
161 | subject_id=1,
162 | resampling_freq=None,
163 | clean_ival_ms=(0, 4000),
164 | train_test_split=True,
165 | clean_on_all_channels=True,
166 | standardize_mode=None):
167 | # TODO: create here another get_data_files_paths function if you have a
168 | # different file configuration; in every case, file_paths must be a
169 | # list of paths to valid BBCI standard files
170 | # getting data paths
171 | file_paths = get_data_files_paths(
172 | data_dir,
173 | subject_id=subject_id,
174 | train_test_split=train_test_split
175 | )
176 |
177 | # starting the loading routine
178 | print_manager('DATA LOADING ROUTINE FOR SUBJ ' + str(subject_id),
179 | 'double-dashed')
180 | print_manager('Loading continuous data...')
181 |
182 | # pre-allocating main cnt
183 | cnt = None
184 |
185 | # loading files and merging them
186 | for idx, current_path in enumerate(file_paths):
187 | current_cnt = load_cnt(file_path=current_path,
188 | channel_names=channel_names,
189 | clean_on_all_channels=clean_on_all_channels)
190 | # if the path is the first one...
191 |         if idx == 0:
192 | # ...copying current_cnt as the main one, else...
193 | cnt = deepcopy(current_cnt)
194 | else:
195 | # merging current_cnt with the main one
196 | cnt = concatenate_raws_with_events([cnt, current_cnt])
197 | print_manager('DONE!!', bottom_return=1)
198 |
199 | # getting clean_trial_mask
200 | print_manager('Getting clean trial mask...')
201 | clean_trial_mask = get_clean_trial_mask(
202 | cnt=cnt,
203 | name_to_start_codes=name_to_start_codes,
204 | clean_ival_ms=clean_ival_ms
205 | )
206 | print_manager('DONE!!', bottom_return=1)
207 |
208 | # pick only right channels
209 | log.info('Picking only right channels...')
210 | cnt = pick_right_channels(cnt, channel_names)
211 | print_manager('DONE!!', bottom_return=1)
212 |
213 | # resample continuous data
214 | if resampling_freq is not None:
215 | log.info('Resampling continuous data...')
216 | cnt = resample_cnt(
217 | cnt,
218 | resampling_freq
219 | )
220 | print_manager('DONE!!', bottom_return=1)
221 |
222 | # standardize continuous data
223 | if standardize_mode is not None:
224 | log.info('Standardizing continuous data...')
225 | log.info('Standardize mode: {}'.format(standardize_mode))
226 | cnt = standardize_cnt(cnt=cnt, standardize_mode=standardize_mode)
227 | print_manager('DONE!!', 'last', bottom_return=1)
228 |
229 | return cnt, clean_trial_mask
230 |
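# A minimal call sketch (the path, channels and event codes below are
# hypothetical placeholders, not values from this project):
#
#   from collections import OrderedDict
#   cnt, clean_trial_mask = load_and_preprocess_data(
#       data_dir='/path/to/data',
#       name_to_start_codes=OrderedDict([('Right Hand', [1]),
#                                        ('Left Hand', [2])]),
#       channel_names=['C3', 'Cz', 'C4'],
#       subject_id=1,
#       resampling_freq=250,
#       standardize_mode=2)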
231 |
232 | def ml_loader(data_dir,
233 | name_to_start_codes,
234 | channel_names,
235 | subject_id=1,
236 | resampling_freq=None,
237 | clean_ival_ms=(0, 4000),
238 | train_test_split=True,
239 | clean_on_all_channels=True,
240 | standardize_mode=None):
241 | outputs = load_and_preprocess_data(
242 | data_dir=data_dir,
243 | name_to_start_codes=name_to_start_codes,
244 | channel_names=channel_names,
245 | subject_id=subject_id,
246 | resampling_freq=resampling_freq,
247 | clean_ival_ms=clean_ival_ms,
248 | train_test_split=train_test_split,
249 | clean_on_all_channels=clean_on_all_channels,
250 | standardize_mode=standardize_mode
251 | )
252 | return outputs[0], outputs[1]
253 |
254 |
255 | def dl_loader(data_dir,
256 | name_to_start_codes,
257 | channel_names,
258 | subject_id=1,
259 | resampling_freq=None,
260 | clean_ival_ms=(0, 4000),
261 | epoch_ival_ms=(-500, 4000),
262 | train_test_split=True,
263 | clean_on_all_channels=True,
264 | standardize_mode=0):
265 | # loading and pre-processing data
266 | cnt, clean_trial_mask = load_and_preprocess_data(
267 | data_dir=data_dir,
268 | name_to_start_codes=name_to_start_codes,
269 | channel_names=channel_names,
270 | subject_id=subject_id,
271 | resampling_freq=resampling_freq,
272 | clean_ival_ms=clean_ival_ms,
273 | train_test_split=train_test_split,
274 | clean_on_all_channels=clean_on_all_channels,
275 | standardize_mode=standardize_mode
276 | )
277 | print_manager('EPOCHING AND CLEANING WITH MASK', 'double-dashed')
278 |
279 | # epoching continuous data (from RawArray to SignalAndTarget)
280 | print_manager('Epoching...')
281 | epo = create_signal_target_from_raw_mne(
282 | cnt,
283 | name_to_start_codes,
284 | epoch_ival_ms
285 | )
286 | print_manager('DONE!!', bottom_return=1)
287 |
288 | # cleaning epoched signal with mask
289 | print_manager('cleaning with mask...')
290 | epo.X = epo.X[clean_trial_mask]
291 | epo.y = epo.y[clean_trial_mask]
292 | print_manager('DONE!!', 'last', bottom_return=1)
293 |
294 | # returning only the epoched signal
295 | return epo
296 |
297 |
298 | class CrossSubject(object):
299 | """
300 |     When an instance of this class is created, it loads all the given
301 |     subject ids in cnt format, together with their clean-trial masks
302 |     and an array recording the indices at which each subject's trials
303 |     start; the data format to keep in memory can then be specified
304 |     case by case through the parser method, which overwrites the data
305 |     in the new format.
306 | """
307 |
308 | def __init__(self,
309 | data_dir,
310 | subject_ids,
311 | channel_names,
312 | name_to_start_codes,
313 | random_state=None,
314 | validation_frac=None,
315 | validation_size=None,
316 | resampling_freq=None,
317 | train_test_split=True,
318 | clean_ival_ms=(-500, 4000),
319 | epoch_ival_ms=(-500, 4000),
320 | clean_on_all_channels=True):
321 | # from input properties
322 | self.data_dir = data_dir
323 | self.subject_ids = subject_ids
324 | self.channel_names = channel_names
325 | self.name_to_start_codes = name_to_start_codes
326 | self.resampling_freq = resampling_freq
327 | self.train_test_split = train_test_split
328 | self.clean_ival_ms = clean_ival_ms
329 | self.epoch_ival_ms = epoch_ival_ms
330 | self.clean_on_all_channels = clean_on_all_channels
331 |
332 | # saving random state; if it is not specified, creating a 1234 one
333 | if random_state is None:
334 | self.random_state = RandomState(1234)
335 | else:
336 | self.random_state = random_state
337 |
338 | # other object properties
339 | self.data = None
340 | self.clean_trial_mask = []
341 | self.subject_labels = None
342 | self.subject_indexes = []
343 | self.folds = None
344 |
345 | # for fold specific data, creating a blank property
346 | self.fold_data = None
347 | self.fold_subject_labels = None
348 |
349 | # loading the first subject (to pre-allocate cnt array)
350 |         # we are going to pass standardize_mode=None, so the loading procedure
351 | # will not standardize data. They will be standardized at the very
352 | # end, when all data are loaded
353 | temp_cnt, temp_mask = load_and_preprocess_data(
354 | data_dir=self.data_dir,
355 | name_to_start_codes=self.name_to_start_codes,
356 | channel_names=self.channel_names,
357 | subject_id=self.subject_ids[0],
358 | resampling_freq=self.resampling_freq,
359 | clean_ival_ms=self.clean_ival_ms,
360 | train_test_split=self.train_test_split,
361 | clean_on_all_channels=self.clean_on_all_channels,
362 | standardize_mode=None
363 | )
364 |
365 | # allocate the first subject_labels
366 | temp_labels = repeat(array([subject_ids[0]]), len(temp_mask))
367 |
368 | # appending new indexes (only cleaned cnt will count!)
369 | last_non_zero_len = count_nonzero(self.clean_trial_mask)
370 | self.subject_indexes.append(
371 | [last_non_zero_len, last_non_zero_len + count_nonzero(temp_mask)]
372 | )
373 |
374 | # merging cnt, mask and labels (in this case assigning)
375 | self.data = temp_cnt
376 | self.clean_trial_mask = temp_mask
377 | self.subject_labels = temp_labels
378 |
379 | # creating iterable object from subject_ids and skipping the first one
380 | iter_subjects = iter(self.subject_ids)
381 | next(iter_subjects)
382 |
383 |         # loading all other subjects' cnt data and concatenating them
384 | for current_subject in iter_subjects:
385 | # loading current subject cnt and mask
386 | temp_cnt, temp_mask = load_and_preprocess_data(
387 | subject_id=current_subject, # here the current subject!
388 | data_dir=self.data_dir,
389 | name_to_start_codes=self.name_to_start_codes,
390 | channel_names=self.channel_names,
391 | resampling_freq=self.resampling_freq,
392 | clean_ival_ms=self.clean_ival_ms,
393 | train_test_split=self.train_test_split,
394 | clean_on_all_channels=self.clean_on_all_channels,
395 | standardize_mode=None
396 | )
397 |
398 | # create the subject_labels for this subject
399 | temp_labels = repeat(array([current_subject]), len(temp_mask))
400 |
401 | # appending new indexes (only cleaned cnt will count!)
402 | last_non_zero_len = count_nonzero(self.clean_trial_mask)
403 | self.subject_indexes.append(
404 | [last_non_zero_len,
405 | last_non_zero_len + count_nonzero(temp_mask)]
406 | )
407 |
408 | # merging cnt and mask
409 | self.data = concatenate_raws_with_events([self.data, temp_cnt])
410 | self.clean_trial_mask = \
411 | concatenate([self.clean_trial_mask, temp_mask])
412 | self.subject_labels = \
413 | concatenate([self.subject_labels, temp_labels])
414 |
415 | # computing validation_frac and validation_size
416 | if validation_size is None:
417 | if validation_frac is None:
418 | self.validation_frac = 0
419 | self.validation_size = 0
420 | else:
421 | self.validation_frac = validation_frac
422 | self.validation_size = \
423 | int(floor(self.n_trials * self.validation_frac))
424 | else:
425 | self.validation_size = validation_size
426 | self.validation_frac = self.validation_size / self.n_trials
427 |
428 | def parser(self, output_format, leave_subj=None, parsing_type=0):
429 | """
430 | HOW DOES IT WORK?
431 | -----------------
432 |         if parsing_type is 0 then the epoched signal will be saved in
433 | fold_data; if parsing_type is 1 then the cnt signal will be replaced
434 | with the epoched one
435 | """
436 |         if output_format == 'epo':
437 |             self.cnt_to_epo(parsing_type=parsing_type)
438 |         elif output_format == 'EEGDataset':
439 | self.cnt_to_epo(parsing_type=parsing_type)
440 | self.epo_to_dataset(leave_subj=leave_subj,
441 | parsing_type=parsing_type)
442 |
443 | def cnt_to_epo(self, parsing_type):
444 | # checking if data is cnt; if not, the method will not work
445 | if isinstance(self.data, RawArray):
446 | """
447 | WHATS GOING ON HERE?
448 | --------------------
449 |             If parsing_type is 0, a 'soft' parsing routine is run: the
450 |             data will be parsed and stored in fold_data instead of in
451 |             the main data property
452 | """
453 | if parsing_type == 0:
454 | # parsing from cnt to epoch
455 | print_manager('Parsing cnt signal to epoched one...')
456 | self.fold_data = create_signal_target_from_raw_mne(
457 | self.data,
458 | self.name_to_start_codes,
459 | self.epoch_ival_ms
460 | )
461 | print_manager('DONE!!', bottom_return=1)
462 |
463 | # cleaning signal and labels with mask
464 | print_manager('Cleaning epoched signal with mask...')
465 | self.fold_data.X = self.fold_data.X[self.clean_trial_mask]
466 | self.fold_data.y = self.fold_data.y[self.clean_trial_mask]
467 | self.fold_subject_labels = \
468 | self.subject_labels[self.clean_trial_mask]
469 | print_manager('DONE!!', bottom_return=1)
470 | elif parsing_type == 1:
471 | """
472 | WHATS GOING ON HERE?
473 | --------------------
474 | If parsing_type is 1, then the epoched signal will replace
475 | the original one in the data property
476 | """
477 | print_manager('Parsing cnt signal to epoched one...')
478 | self.data = create_signal_target_from_raw_mne(
479 | self.data,
480 | self.name_to_start_codes,
481 | self.epoch_ival_ms
482 | )
483 | print_manager('DONE!!', bottom_return=1)
484 |
485 | # cleaning signal and labels
486 | print_manager('Cleaning epoched signal with mask...')
487 | self.data.X = self.data.X[self.clean_trial_mask]
488 | self.data.y = self.data.y[self.clean_trial_mask]
489 | self.subject_labels = \
490 | self.subject_labels[self.clean_trial_mask]
491 | print_manager('DONE!!', bottom_return=1)
492 | else:
493 | raise ValueError(
494 | 'parsing_type {} not supported.'.format(parsing_type)
495 | )
496 |
497 | # now that we have an epoched signal, we can already create
498 | # folds for cross-subject validation
499 | self.create_balanced_folds()
500 |
501 | def epo_to_dataset(self, leave_subj, parsing_type=0):
502 | print_manager('FOLD ALL BUT ' + str(leave_subj), 'double-dashed')
503 | print_manager('Creating current fold...')
504 |
505 | print_manager('DONE!!', bottom_return=1)
506 | print_manager('Parsing epoched signal to EEGDataset...')
507 |         if parsing_type == 0:
508 |             self.fold_data = CrossValidation.create_dataset_static(
509 |                 self.fold_data, self.folds[leave_subj - 1]
510 |             )
511 |         elif parsing_type == 1:
512 | self.fold_data = CrossValidation.create_dataset_static(
513 | self.data, self.folds[leave_subj - 1]
514 | )
515 | else:
516 | raise ValueError(
517 | 'parsing_type {} not supported.'.format(parsing_type)
518 | )
519 | print_manager('DONE!!', bottom_return=1)
520 | print_manager('We obtained a ' + str(self.fold_data))
521 | print_manager('DATA READY!!', 'last', bottom_return=1)
522 |
523 | def create_balanced_folds(self):
524 | # pre-allocating folds
525 | self.folds = []
526 | for subj_idx, subj_idxs in enumerate(self.subject_indexes):
527 | # getting current test_idxs (all a subject trials)
528 | test_idxs = arange(subj_idxs[0], subj_idxs[1])
529 |
530 | # getting train_idxs as all but the current subject
531 | train_idxs = setdiff1d(arange(self.n_trials), test_idxs)
532 |
533 | # pre-allocating valid_idxs
534 | valid_idxs = array([], dtype='int')
535 |
536 | # if no validation set is required...
537 | if self.validation_frac == 0:
538 | # setting valid_idxs to None, else...
539 | valid_idxs = None
540 | else:
541 | # ...determining number of splits for this train/validation set
542 | n_splits = int(floor(self.validation_frac * 100))
543 |
544 |                 # getting StratifiedKFold object
545 | skf = StratifiedKFold(n_splits=n_splits,
546 | random_state=self.random_state,
547 | shuffle=True)
548 |
549 | # cycling on subject in the train fold
550 | for c_subj_idx, c_subj_idxs in enumerate(self.subject_indexes):
551 | if c_subj_idx == subj_idx:
552 | # nothing to do
553 | pass
554 | else:
555 | # splitting first subject train / valid
556 | X, y = self._get_subject_data(c_subj_idx)
557 |
558 | # get batch from StratifiedKFold object
559 | for c_train_idxs, c_valid_idxs in skf.split(X=X, y=y):
560 | # referring c_train_idxs and c_valid_idxs
561 | c_train_idxs += c_subj_idxs[0]
562 | c_valid_idxs += c_subj_idxs[0]
563 |
564 | # remove this batch indexes from train_idxs
565 | train_idxs = setdiff1d(train_idxs, c_valid_idxs)
566 |
567 | # adding this batch indexes to valid_idxs
568 | valid_idxs = concatenate([valid_idxs,
569 | c_valid_idxs])
570 | # all is done for this subject!! Breaking cycle
571 | break
572 |
573 | # appending new fold
574 | self.folds.append(
575 | {
576 | 'train': train_idxs,
577 | 'valid': valid_idxs,
578 | 'test': test_idxs
579 | }
580 | )
581 |
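    # The result: self.folds[i] is a dict with keys 'train', 'valid' and
    # 'test', where 'test' holds every trial of the left-out subject and
    # 'valid' (when validation_frac > 0) holds a stratified slice of each
    # remaining subject's trials -- the leave-one-subject-out layout
    # consumed by epo_to_dataset above.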
582 | def _get_subject_data(self, subj_idx):
583 | init = self.subject_indexes[subj_idx][0]
584 | stop = self.subject_indexes[subj_idx][1]
585 | ival = arange(init, stop)
586 | if isinstance(self.fold_data, SignalAndTarget):
587 | return self.fold_data.X[ival], self.fold_data.y[ival]
588 | elif isinstance(self.data, SignalAndTarget):
589 | return self.data.X[ival], self.data.y[ival]
590 | else:
591 | raise ValueError('You are trying to get epoched data but you '
592 | 'still have to parse cnt data.')
593 |
594 | @property
595 | def n_trials(self):
596 | return count_nonzero(self.clean_trial_mask)
597 |
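# A hedged usage sketch for CrossSubject (the path, ids, channels and
# event codes are illustrative placeholders):
#
#   from collections import OrderedDict
#   cross_obj = CrossSubject(data_dir='/path/to/data',
#                            subject_ids=[1, 2, 3],
#                            channel_names=['C3', 'Cz', 'C4'],
#                            name_to_start_codes=OrderedDict(
#                                [('Right Hand', [1]), ('Left Hand', [2])]),
#                            validation_frac=0.1)
#   # epoch once, then build the leave-subject-1-out fold as an EEGDataset
#   cross_obj.parser('EEGDataset', leave_subj=1, parsing_type=0)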
--------------------------------------------------------------------------------
/hgdecode/models.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 | from keras.models import Model
3 | from keras.layers import Dense
4 | from keras.layers import Input
5 | from keras.layers import Conv2D
6 | from keras.layers import Permute
7 | from keras.layers import Dropout
8 | from keras.layers import Flatten
9 | from keras.layers import Activation
10 | from keras.layers import MaxPooling2D
11 | from keras.layers import SeparableConv2D
12 | from keras.layers import DepthwiseConv2D
13 | from keras.layers import SpatialDropout2D
14 | from keras.layers import AveragePooling2D
15 | from keras.layers import BatchNormalization
16 | from keras.constraints import max_norm
17 | from keras.regularizers import l1_l2
18 |
19 |
20 | # TODO: define pool_size, strides and other parameters as experiment
21 | # properties, tunable from the user in the main script; then passing the
22 | # entire experiment to the model function constructor or each useful
23 | # parameter individually
24 |
25 |
26 | # %% DEEP CONV NET
27 | def DeepConvNet(n_classes=4,
28 | n_channels=64,
29 | n_samples=256,
30 | dropout_rate=0.5):
31 | """ Keras implementation of the Deep Convolutional Network as described in
32 | Schirrmeister et. al. (2017), Human Brain Mapping.
33 |     This implementation keeps the hyperparameters of the original 250Hz
34 |     setup rather than the reduced ones sometimes used for 128Hz data:
35 |     as the code below shows, the temporal convolutions have length
36 |     (1, 10) and max-pooling uses a pool size and stride of (1, 3),
37 |     exactly as described in the paper.
38 |     Note that we use the max_norm constraint on all convolutional
39 |     layers, as well as the classification layer. We also change the
40 |     defaults for the BatchNormalization layer. We used this based on a
41 |     personal communication with the original authors.
45 | Note that this implementation has not been verified by the original
46 | authors.
47 | """
48 |
49 | # start the model
50 | input_main = Input((1, n_channels, n_samples))
51 | block1 = Conv2D(25, (1, 10),
52 | input_shape=(1, n_channels, n_samples),
53 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(input_main)
54 | block1 = Conv2D(25, (n_channels, 1),
55 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1)
56 | block1 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block1)
57 | block1 = Activation('elu')(block1)
58 | block1 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block1)
59 | block1 = Dropout(dropout_rate)(block1)
60 |
61 | block2 = Conv2D(50, (1, 10),
62 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1)
63 | block2 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block2)
64 | block2 = Activation('elu')(block2)
65 | block2 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block2)
66 | block2 = Dropout(dropout_rate)(block2)
67 |
68 | block3 = Conv2D(100, (1, 10),
69 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block2)
70 | block3 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block3)
71 | block3 = Activation('elu')(block3)
72 | block3 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block3)
73 | block3 = Dropout(dropout_rate)(block3)
74 |
75 | block4 = Conv2D(200, (1, 10),
76 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block3)
77 | block4 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block4)
78 | block4 = Activation('elu')(block4)
79 | block4 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block4)
80 | block4 = Dropout(dropout_rate)(block4)
81 |
82 | flatten = Flatten()(block4)
83 |
84 | dense = Dense(n_classes, kernel_constraint=max_norm(0.5))(flatten)
85 | softmax = Activation('softmax')(dense)
86 |
87 | return Model(inputs=input_main, outputs=softmax)
88 |
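# A minimal build-and-compile sketch (the optimizer, loss and the
# channels_first assumption are additions for illustration, not choices
# taken from this project's experiments):
#
#   from keras import backend as K
#   K.set_image_data_format('channels_first')  # inputs are (1, ch, samples)
#   model = DeepConvNet(n_classes=4, n_channels=128, n_samples=1000)
#   model.compile(optimizer='adam', loss='categorical_crossentropy',
#                 metrics=['accuracy'])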
89 |
90 | # %% DEEP CONV NET 500 Hz
91 | def DeepConvNet_500Hz(n_classes=4,
92 | n_channels=64,
93 | n_samples=256,
94 | dropout_rate=0.5):
95 | """
96 |     DeepConvNet variant for 500Hz data: temporal kernels are widened to (1, 20) and pooling reduced to (1, 2) to compensate for the higher sampling rate.
97 | """
98 | # input
99 | input_main = Input((1, n_channels, n_samples))
100 |
101 | # block1
102 | block1 = Conv2D(25, (1, 20),
103 | # bias_initializer='truncated_normal',
104 | # kernel_initializer='he_normal',
105 | # kernel_regularizer=l2(0.0001),
106 | input_shape=(1, n_channels, n_samples),
107 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
108 | )(input_main)
109 | block1 = Conv2D(25, (n_channels, 1),
110 | # bias_initializer='truncated_normal',
111 | # kernel_initializer='he_normal',
112 | # kernel_regularizer=l2(0.0001),
113 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
114 | )(block1)
115 | block1 = BatchNormalization(axis=1)(block1)
116 | block1 = Activation('elu')(block1)
117 | block1 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block1)
118 | block1 = Dropout(dropout_rate)(block1)
119 |
120 | # block2
121 | block2 = Conv2D(50, (1, 20),
122 | # bias_initializer='truncated_normal',
123 | # kernel_initializer='he_normal',
124 | # kernel_regularizer=l2(0.0001),
125 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
126 | )(block1)
127 | block2 = BatchNormalization(axis=1)(block2)
128 | block2 = Activation('elu')(block2)
129 | block2 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block2)
130 | block2 = Dropout(dropout_rate)(block2)
131 |
132 | # block3
133 | block3 = Conv2D(100, (1, 20),
134 | # bias_initializer='truncated_normal',
135 | # kernel_initializer='he_normal',
136 | # kernel_regularizer=l2(0.0001),
137 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
138 | )(block2)
139 | block3 = BatchNormalization(axis=1)(block3)
140 | block3 = Activation('elu')(block3)
141 | block3 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block3)
142 | block3 = Dropout(dropout_rate)(block3)
143 |
144 | # block4
145 | block4 = Conv2D(200, (1, 20),
146 | # bias_initializer='truncated_normal',
147 | # kernel_initializer='he_normal',
148 | # kernel_regularizer=l2(0.0001),
149 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
150 | )(block3)
151 | block4 = BatchNormalization(axis=1)(block4)
152 | block4 = Activation('elu')(block4)
153 | block4 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block4)
154 | block4 = Dropout(dropout_rate)(block4)
155 |
156 | # flatten
157 | flatten = Flatten()(block4)
158 |
159 | # another dense one
160 | # dense = Dense(128, bias_initializer='truncated_normal',
161 | # kernel_initializer='he_normal',
162 | # kernel_regularizer=l2(0.001),
163 | # kernel_constraint=max_norm(0.5))(flatten)
164 | # dense = Activation('elu')(dense)
165 | # dense = Dropout(dropout_rate)(dense)
166 |
167 | # dense
168 | dense = Dense(n_classes,
169 | # bias_initializer='truncated_normal',
170 | # kernel_initializer='truncated_normal',
171 | kernel_constraint=max_norm(0.5)
172 | )(flatten)
173 | softmax = Activation('softmax')(dense)
174 |
175 | # returning the model
176 | return Model(inputs=input_main, outputs=softmax)
177 |
178 |
179 | # %% DEEP CONV NET DAVIDE
180 | def DeepConvNet_Davide(n_classes=4,
181 | n_channels=64,
182 | n_samples=256,
183 | dropout_rate=0.5):
184 | """
185 |     Project-specific DeepConvNet variant: same topology as DeepConvNet, but with Keras default BatchNormalization parameters.
186 |     :param n_classes: number of output classes
187 |     :param n_channels: number of EEG channels
188 |     :param n_samples: number of time samples per trial
189 |     :param dropout_rate: dropout fraction
190 |     :return: the (uncompiled) Keras Model
191 | """
192 | # start the model
193 | input_main = Input((1, n_channels, n_samples))
194 | block1 = Conv2D(25, (1, 10),
195 | # bias_initializer='truncated_normal',
196 | # kernel_initializer='he_normal',
197 | # kernel_regularizer=l2(0.0001),
198 | input_shape=(1, n_channels, n_samples),
199 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
200 | )(input_main)
201 | block1 = Conv2D(25, (n_channels, 1),
202 | # bias_initializer='truncated_normal',
203 | # kernel_initializer='he_normal',
204 | # kernel_regularizer=l2(0.0001),
205 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
206 | )(block1)
207 | block1 = BatchNormalization(axis=1)(block1)
208 | block1 = Activation('elu')(block1)
209 |
210 | block1 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block1)
211 | block1 = Dropout(dropout_rate)(block1)
212 |
213 | block2 = Conv2D(50, (1, 10),
214 | # bias_initializer='truncated_normal',
215 | # kernel_initializer='he_normal',
216 | # kernel_regularizer=l2(0.0001),
217 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
218 | )(block1)
219 | block2 = BatchNormalization(axis=1)(block2)
220 | block2 = Activation('elu')(block2)
221 | block2 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block2)
222 | block2 = Dropout(dropout_rate)(block2)
223 |
224 | block3 = Conv2D(100, (1, 10),
225 | # bias_initializer='truncated_normal',
226 | # kernel_initializer='he_normal',
227 | # kernel_regularizer=l2(0.0001),
228 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
229 | )(block2)
230 | block3 = BatchNormalization(axis=1)(block3)
231 | block3 = Activation('elu')(block3)
232 |
233 | block3 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block3)
234 | block3 = Dropout(dropout_rate)(block3)
235 |
236 | block4 = Conv2D(200, (1, 10),
237 | # bias_initializer='truncated_normal',
238 | # kernel_initializer='he_normal',
239 | # kernel_regularizer=l2(0.0001),
240 | kernel_constraint=max_norm(2., axis=(0, 1, 2))
241 | )(block3)
242 | block4 = BatchNormalization(axis=1)(block4)
243 | block4 = Activation('elu')(block4)
244 |
245 | block4 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block4)
246 | block4 = Dropout(dropout_rate)(block4)
247 |
248 | flatten = Flatten()(block4)
249 | # dense = Dense(128, bias_initializer='truncated_normal',
250 | # kernel_initializer='he_normal',
251 | # kernel_regularizer=l2(0.001),
252 | # kernel_constraint=max_norm(0.5))(flatten)
253 | # dense = Activation('elu')(dense)
254 | # dense = Dropout(dropout_rate)(dense)
255 | dense = Dense(n_classes,
256 | # bias_initializer='truncated_normal',
257 | # kernel_initializer='truncated_normal',
258 | kernel_constraint=max_norm(0.5)
259 | )(flatten)
260 | softmax = Activation('softmax')(dense)
261 |
262 | return Model(inputs=input_main, outputs=softmax)
263 |
264 |
265 | # %% SHALLOW CONV NET
266 | def square(x):
267 | return K.square(x)
268 |
269 |
270 | def log(x):
271 | return K.log(K.clip(x, min_value=1e-7, max_value=10000))
272 |
273 |
274 | def ShallowConvNet(n_classes,
275 | n_channels=64,
276 | n_samples=128,
277 | dropout_rate=0.5):
278 | """ Keras implementation of the Shallow Convolutional Network as described
279 | in Schirrmeister et. al. (2017), Human Brain Mapping.
280 |
281 | Assumes the input is a 2-second EEG signal sampled at 128Hz. Note that in
282 | the original paper, they do temporal convolutions of length 25 for EEG
283 | data sampled at 250Hz. We instead use length 13 since the sampling rate is
284 | roughly half of the 250Hz which the paper used. The pool_size and stride
285 | in later layers is also approximately half of what is used in the paper.
286 |
287 | Note that we use the max_norm constraint on all convolutional layers, as
288 | well as the classification layer. We also change the defaults for the
289 | BatchNormalization layer. We used this based on a personal communication
290 | with the original authors.
291 |
292 | ours original paper
293 | pool_size 1, 35 1, 75
294 | strides 1, 7 1, 15
295 | conv filters 1, 13 1, 25
296 |
297 | Note that this implementation has not been verified by the original
298 | authors. We do note that this implementation reproduces the results in the
299 | original paper with minor deviations.
300 | """
301 |
302 | # start the model
303 | input_main = Input((1, n_channels, n_samples))
304 | block1 = Conv2D(40, (1, 13),
305 | input_shape=(1, n_channels, n_samples),
306 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(input_main)
307 | block1 = Conv2D(40, (n_channels, 1), use_bias=False,
308 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1)
309 | block1 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block1)
310 | block1 = Activation(square)(block1)
311 | block1 = AveragePooling2D(pool_size=(1, 35), strides=(1, 7))(block1)
312 | block1 = Activation(log)(block1)
313 | block1 = Dropout(dropout_rate)(block1)
314 | flatten = Flatten()(block1)
315 | dense = Dense(n_classes, kernel_constraint=max_norm(0.5))(flatten)
316 | softmax = Activation('softmax')(dense)
317 |
318 | return Model(inputs=input_main, outputs=softmax)
319 |
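# The square -> average-pool -> log chain above mirrors the classic log
# band-power pipeline: squaring and temporal averaging estimate the signal
# variance in a sliding window, and the log makes the result comparable to
# the log-variance features traditionally used with CSP.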
320 |
321 | # %% EEG NET
322 | def EEGNet(n_classes,
323 | n_channels=64,
324 | n_samples=128,
325 | dropout_rate=0.25,
326 | kernel_length=64,
327 | F1=4,
328 | D=2,
329 | F2=8,
330 | norm_rate=0.25,
331 | dropout_type='Dropout'):
332 | """ Keras Implementation of EEGNet
333 | http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta
334 | Note that this implements the newest version of EEGNet and NOT the earlier
335 | version (version v1 and v2 on arxiv). We strongly recommend using this
336 | architecture as it performs much better and has nicer properties than
337 | our earlier version. For example:
338 |
339 | 1. Depthwise Convolutions to learn spatial filters within a
340 | temporal convolution. The use of the depth_multiplier option maps
341 | exactly to the number of spatial filters learned within a temporal
342 | filter. This matches the setup of algorithms like FBCSP which learn
343 | spatial filters within each filter in a filter-bank. This also limits
344 | the number of free parameters to fit when compared to a fully-connected
345 | convolution.
346 |
347 | 2. Separable Convolutions to learn how to optimally combine spatial
348 | filters across temporal bands. Separable Convolutions are Depthwise
349 | Convolutions followed by (1x1) Pointwise Convolutions.
350 |
351 |
352 | While the original paper used Dropout, we found that SpatialDropout2D
353 | sometimes produced slightly better results for classification of ERP
354 | signals. However, SpatialDropout2D significantly reduced performance
355 | on the Oscillatory dataset (SMR, BCI-IV Dataset 2A). We recommend using
356 | the default Dropout in most cases.
357 |
358 | Assumes the input signal is sampled at 128Hz. If you want to use this model
359 | for any other sampling rate you will need to modify the lengths of temporal
360 | kernels and average pooling size in blocks 1 and 2 as needed (double the
361 | kernel lengths for double the sampling rate, etc). Note that we haven't
362 | tested the model performance with this rule so this may not work well.
363 |
364 | The model with default parameters gives the EEGNet-4,2 model as discussed
365 | in the paper. This model should do pretty well in general, although as the
366 | paper discussed the EEGNet-8,2 (with 8 temporal kernels and 2 spatial
367 | filters per temporal kernel) can do slightly better on the SMR dataset.
368 | Other variations that we found to work well are EEGNet-4,1 and EEGNet-8,1.
369 | We set F2 = F1 * D (number of input filters = number of output filters) for
370 | the SeparableConv2D layer. We haven't extensively tested other values of
371 | this parameter (say, F2 < F1 * D for compressed learning, and F2 > F1 * D
372 | for overcomplete). We believe the main parameters to focus on are F1 and D.
373 | Inputs:
374 |
375 | n_classes : int, number of classes to classify
376 | n_channels : number of channels
377 | n_samples : number of time points in the EEG data
378 | dropout_rate : dropout fraction
379 | kernel_length : length of temporal convolution in first layer. We found
380 | that setting this to be half the sampling rate worked
381 | well in practice. For the SMR dataset in particular
382 | since the data was high-passed at 4Hz we used a kernel
383 | length of 32.
384 | F1, F2 : number of temporal filters (F1) and number of pointwise
385 | filters (F2) to learn. Default: F1 = 4, F2 = F1 * D.
386 | D : number of spatial filters to learn within each temporal
387 | convolution. Default: D = 2
388 | dropout_type : Either SpatialDropout2D or Dropout, passed as a string.
389 | """
390 |
391 | if dropout_type == 'SpatialDropout2D':
392 | dropout_type = SpatialDropout2D
393 | elif dropout_type == 'Dropout':
394 | dropout_type = Dropout
395 | else:
396 | raise ValueError('dropout_type must be one of SpatialDropout2D '
397 | 'or Dropout, passed as a string.')
398 |
399 | input1 = Input(shape=(1, n_channels, n_samples))
400 |
401 | ##################################################################
402 | block1 = Conv2D(F1, (1, kernel_length), padding='same',
403 | input_shape=(1, n_channels, n_samples),
404 | use_bias=False)(input1)
405 | block1 = BatchNormalization(axis=1)(block1)
406 | block1 = DepthwiseConv2D((n_channels, 1), use_bias=False,
407 | depth_multiplier=D,
408 | depthwise_constraint=max_norm(1.))(block1)
409 | block1 = BatchNormalization(axis=1)(block1)
410 | block1 = Activation('elu')(block1)
411 | block1 = AveragePooling2D((1, 4))(block1)
412 | block1 = dropout_type(dropout_rate)(block1)
413 |
414 | block2 = SeparableConv2D(F2, (1, 16),
415 | use_bias=False, padding='same')(block1)
416 | block2 = BatchNormalization(axis=1)(block2)
417 | block2 = Activation('elu')(block2)
418 | block2 = AveragePooling2D((1, 8))(block2)
419 | block2 = dropout_type(dropout_rate)(block2)
420 |
421 | flatten = Flatten(name='flatten')(block2)
422 |
423 | dense = Dense(n_classes, name='dense',
424 | kernel_constraint=max_norm(norm_rate))(flatten)
425 | softmax = Activation('softmax', name='softmax')(dense)
426 |
427 | return Model(inputs=input1, outputs=softmax)
428 |
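# As the docstring notes, the EEGNet-8,2 variant is obtained by doubling
# the temporal filters while keeping F2 = F1 * D:
#
#   model = EEGNet(n_classes=4, n_channels=64, n_samples=128,
#                  kernel_length=64, F1=8, D=2, F2=16)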
429 |
430 | # %% EEG NET SSVEP
431 | def EEGNet_SSVEP(n_classes=12,
432 | n_channels=8,
433 | n_samples=256,
434 | dropout_rate=0.5,
435 | kernel_length=256,
436 | F1=96,
437 | D=1,
438 | F2=96,
439 | dropout_type='Dropout'):
440 | """ SSVEP Variant of EEGNet, as used in [1].
441 | Inputs:
442 |
443 | n_classes : int, number of classes to classify
444 | n_channels : number of channels
445 | n_samples : number of time points in the EEG data
446 | dropout_rate : dropout fraction
447 | kernel_length : length of temporal convolution in first layer
448 | F1, F2 : number of temporal filters (F1) and number of pointwise
449 | filters (F2) to learn.
450 | D : number of spatial filters to learn within each temporal
451 | convolution.
452 | dropout_type : Either SpatialDropout2D or Dropout, passed as a string.
453 |
454 |
455 | [1]. Waytowich, N. et. al. (2018). Compact Convolutional Neural Networks
456 | for Classification of Asynchronous Steady-State Visual Evoked Potentials.
457 | Journal of Neural Engineering vol. 15(6).
458 | http://iopscience.iop.org/article/10.1088/1741-2552/aae5d8
459 | """
460 |
461 | if dropout_type == 'SpatialDropout2D':
462 | dropout_type = SpatialDropout2D
463 | elif dropout_type == 'Dropout':
464 | dropout_type = Dropout
465 | else:
466 | raise ValueError('dropout_type must be one of SpatialDropout2D '
467 | 'or Dropout, passed as a string.')
468 |
469 | input1 = Input(shape=(1, n_channels, n_samples))
470 |
471 | ##################################################################
472 | block1 = Conv2D(F1, (1, kernel_length), padding='same',
473 | input_shape=(1, n_channels, n_samples),
474 | use_bias=False)(input1)
475 | block1 = BatchNormalization(axis=1)(block1)
476 | block1 = DepthwiseConv2D((n_channels, 1), use_bias=False,
477 | depth_multiplier=D,
478 | depthwise_constraint=max_norm(1.))(block1)
479 | block1 = BatchNormalization(axis=1)(block1)
480 | block1 = Activation('elu')(block1)
481 | block1 = AveragePooling2D((1, 4))(block1)
482 | block1 = dropout_type(dropout_rate)(block1)
483 |
484 | block2 = SeparableConv2D(F2, (1, 16),
485 | use_bias=False, padding='same')(block1)
486 | block2 = BatchNormalization(axis=1)(block2)
487 | block2 = Activation('elu')(block2)
488 | block2 = AveragePooling2D((1, 8))(block2)
489 | block2 = dropout_type(dropout_rate)(block2)
490 |
491 | flatten = Flatten(name='flatten')(block2)
492 |
493 | dense = Dense(n_classes, name='dense')(flatten)
494 | softmax = Activation('softmax', name='softmax')(dense)
495 |
496 | return Model(inputs=input1, outputs=softmax)
497 |
498 |
499 | # %% EEG NET OLD
500 | def EEGNet_old(n_classes,
501 | n_channels=64,
502 | n_samples=128,
503 | regRate=0.0001,
504 | dropout_rate=0.25,
505 | kernels=None,
506 | strides=(2, 4)):
507 | """ Keras Implementation of EEGNet_v1 (https://arxiv.org/abs/1611.08024v2)
508 | This model is the original EEGNet model proposed on arxiv
509 | https://arxiv.org/abs/1611.08024v2
510 |
511 | with a few modifications: we use striding instead of max-pooling as this
512 | helped slightly in classification performance while also providing a
513 | computational speed-up.
514 |
515 | Note that we no longer recommend the use of this architecture, as the new
516 | version of EEGNet performs much better overall and has nicer properties.
517 |
518 | Inputs:
519 |
520 | n_classes : total number of final categories
521 | n_channels : number of EEG channels
522 | n_samples : number of EEG time points
523 | regRate : regularization rate for L1 and L2 regularizations
524 | dropout_rate : dropout fraction
525 | kernels : the 2nd and 3rd layer kernel dimensions (default is
526 | the [2, 32] x [8, 4] configuration)
527 | strides : the stride size (note that this replaces the max-pool
528 | used in the original paper)
529 |
530 | """
531 | # fixing PEP8 mutable input issue
532 | if kernels is None:
533 | kernels = [(2, 32), (8, 4)]
534 |
535 | # start the model
536 | input_main = Input((1, n_channels, n_samples))
537 | layer1 = Conv2D(16, (n_channels, 1),
538 | input_shape=(1, n_channels, n_samples),
539 | kernel_regularizer=l1_l2(l1=regRate, l2=regRate))(
540 | input_main)
541 | layer1 = BatchNormalization(axis=1)(layer1)
542 | layer1 = Activation('elu')(layer1)
543 | layer1 = Dropout(dropout_rate)(layer1)
544 |
545 | permute_dims = 2, 1, 3
546 | permute1 = Permute(permute_dims)(layer1)
547 |
548 | layer2 = Conv2D(4, kernels[0], padding='same',
549 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate),
550 | strides=strides)(permute1)
551 | layer2 = BatchNormalization(axis=1)(layer2)
552 | layer2 = Activation('elu')(layer2)
553 | layer2 = Dropout(dropout_rate)(layer2)
554 |
555 | layer3 = Conv2D(4, kernels[1], padding='same',
556 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate),
557 | strides=strides)(layer2)
558 | layer3 = BatchNormalization(axis=1)(layer3)
559 | layer3 = Activation('elu')(layer3)
560 | layer3 = Dropout(dropout_rate)(layer3)
561 |
562 | flatten = Flatten(name='flatten')(layer3)
563 |
564 | dense = Dense(n_classes, name='dense')(flatten)
565 | softmax = Activation('softmax', name='softmax')(dense)
566 |
567 | return Model(inputs=input_main, outputs=softmax)
568 |
--------------------------------------------------------------------------------
/hgdecode/signalproc.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import numpy as np
4 | import scipy as sp
5 |
6 | from braindecode.datautil.signal_target import SignalAndTarget
7 | from braindecode.datautil.signalproc import bandpass_cnt
8 | from braindecode.mne_ext.signalproc import mne_apply
9 |
10 |
11 | def bandpass_mne(cnt, low_cut_hz, high_cut_hz, filt_order=3, axis=0):
12 | return mne_apply(lambda data: bandpass_cnt(
13 | data.T, low_cut_hz, high_cut_hz, fs=cnt.info['sfreq'],
14 | filt_order=filt_order, axis=axis).T,
15 | cnt)
16 |
17 |
18 | def select_trials(dataset, inds):
19 | if hasattr(dataset.X, 'ndim'):
20 | # numpy array
21 | new_X = np.array(dataset.X)[inds]
22 | else:
23 | # list
24 | new_X = [dataset.X[i] for i in inds]
25 | new_y = np.asarray(dataset.y)[inds]
26 | return SignalAndTarget(new_X, new_y)
27 |
28 |
29 | def select_classes_cnt(cnt, class_numbers):
30 | cnt = deepcopy(cnt)
31 | events = cnt.info['events']
32 | new_events = [ev for ev in events if
33 | (ev[2] - ev[1]) in class_numbers]
34 | cnt.info['events'] = np.array(new_events)
35 | return cnt
36 |
37 |
38 | def select_classes(dataset, class_numbers):
39 | wanted_inds = [i_trial for i_trial, y in enumerate(dataset.y)
40 | if y in class_numbers]
41 | return select_trials(dataset, wanted_inds)
42 |
43 |
44 | def select_trials_cnt(cnt, inds):
45 | cnt = deepcopy(cnt)
46 | assert np.all(
47 | [i in np.arange(len(cnt.info['events'])) for i in inds])
48 | events = cnt.info['events']
49 | new_events = [ev for i_trial, ev in enumerate(events) if
50 | i_trial in inds]
51 | cnt.info['events'] = np.array(new_events)
52 | return cnt
53 |
54 |
55 | def concatenate_channels(datasets):
56 | all_X = [dataset.X for dataset in datasets]
57 | new_X = np.concatenate(all_X, axis=1)
58 | new_y = datasets[0].y
59 | for dataset in datasets:
60 | assert np.array_equal(dataset.y, new_y)
61 | return SignalAndTarget(new_X, new_y)
62 |
63 |
64 | def extract_all_start_codes(name_to_start_codes):
65 | all_start_codes = []
66 | for val in name_to_start_codes.values():
67 | if hasattr(val, '__len__'):
68 | all_start_codes.extend(val)
69 | else:
70 | all_start_codes.append(val)
71 | return all_start_codes
72 |
73 |
74 | def calculate_csp(epo, classes=None, average_trial_covariance=False):
75 | """Calculate the Common Spatial Pattern (CSP) for two classes.
76 | Now with pattern computation as in matlab bbci toolbox
77 | https://github.com/bbci/bbci_public/blob/c7201e4e42f873cced2e068c6cbb3780a8f8e9ec/processing/proc_csp.m#L112
78 |
79 | This method calculates the CSP and the corresponding filters. Use
80 | the columns of the patterns and filters.
81 | Examples
82 | --------
83 | Calculate the CSP for the first two classes::
84 | >> w, a, d = calculate_csp(epo)
85 | >> # Apply the first two and the last two columns of the sorted
86 | >> # filter to the data
87 | >> filtered = apply_spatial_filter(epo, w[:, [0, 1, -2, -1]])
88 | >> # You'll probably want to get the log-variance along the time
89 | >> # axis, this should result in four numbers (one for each
90 | >> # channel)
91 | >> filtered = np.log(np.var(filtered, 0))
92 | Select two classes manually::
93 | >> w, a, d = calculate_csp(epo, [2, 5])
94 | Parameters
95 | ----------
96 | epo : epoched Data object
97 | this method relies on the ``epo`` to have three dimensions in
98 | the following order: class, time, channel
99 | classes : list of two ints, optional
100 | If ``None`` the first two different class indices found in
101 | ``epo.axes[0]`` are chosen automatically otherwise the class
102 | indices can be manually chosen by setting ``classes``
103 |     average_trial_covariance : bool, if True average per-trial covariances
104 | Returns
105 | -------
106 | v : 2d array
107 | the sorted spatial filters
108 | a : 2d array
109 | the sorted spatial patterns. Column i of a represents the
110 | pattern of the filter in column i of v.
111 | d : 1d array
112 | the variances of the components
113 | Raises
114 | ------
115 | AssertionError :
116 | If:
117 | * ``classes`` is not ``None`` and has less than two elements
118 | * ``classes`` is not ``None`` and the first two elements are
119 | not found in the ``epo``
120 | * ``classes`` is ``None`` but there are less than two
121 | different classes in the ``epo``
122 | See Also
123 | --------
124 | :func:`apply_spatial_filter`, :func:`apply_csp`, :func:`calculate_spoc`
125 | References
126 | ----------
127 | http://en.wikipedia.org/wiki/Common_spatial_pattern
128 | """
129 | if classes is None:
130 | # automagically find the first two different classidx
131 | # we don't use uniq, since it sorts the classidx first
132 |         # first check that we have at least two different idxs:
133 | unique_classes = np.unique(epo.y)
134 | assert len(unique_classes) == 2
135 | cidx1 = unique_classes[0]
136 | cidx2 = unique_classes[1]
137 | else:
138 | assert (len(classes) == 2 and
139 | classes[0] in epo.y and
140 | classes[1] in epo.y)
141 | cidx1 = classes[0]
142 | cidx2 = classes[1]
143 | epo1 = select_classes(epo, [cidx1])
144 | epo2 = select_classes(epo, [cidx2])
145 | if average_trial_covariance:
146 | # computing c1 as mean covariance of trial covariances:
147 | c1 = np.mean([np.cov(x) for x in epo1.X], axis=0)
148 | c2 = np.mean([np.cov(x) for x in epo2.X], axis=0)
149 | else:
150 | # we need a matrix of the form (channels, observations) so we stack
151 | # trials and time per channel together
152 | x1 = np.concatenate(epo1.X, axis=1)
153 | x2 = np.concatenate(epo2.X, axis=1)
154 | # compute covariance matrices of the two classes
155 | c1 = np.cov(x1)
156 | c2 = np.cov(x2)
157 | # solution of csp objective via generalized eigenvalue problem
158 | # in matlab the signature is v, d = eig(a, b)
159 |
160 | # d1, v1 = sp.linalg.eigh(c2, c1 + c2) # old instruction
161 | d, v = sp.linalg.eig(c2, c1 + c2)
162 | d = d.real
163 | # make sure the eigenvalues and -vectors are correctly sorted
164 | indx = np.argsort(d)
165 | # reverse
166 | indx = indx[::-1]
167 | d = d.take(indx)
168 | v = v.take(indx, axis=1)
169 |
170 | # Now compute patterns
171 | # old pattern computation
172 | # a = sp.linalg.inv(v).transpose()
173 | c_avg = (c1 + c2) / 2.0
174 |
175 | # compare
176 | # https://github.com/bbci/bbci_public/blob/
177 | # c7201e4e42f873cced2e068c6cbb3780a8f8e9ec/processing/proc_csp.m#L112
178 | # with W := v
179 | v_with_cov = np.dot(c_avg, v)
180 | source_cov = np.dot(np.dot(v.T, c_avg), v)
181 | # matlab-python comparison
182 | """
183 | v_with_cov = np.array([[1,2,-2],
184 | [3,-2,4],
185 | [5,1,0.3]])
186 |
187 | source_cov = np.array([[1,2,0.5],
188 | [2,0.6,4],
189 | [0.5,4,2]])
190 |
191 | sp.linalg.solve(source_cov.T, v_with_cov.T).T
192 | # for matlab
193 | v_with_cov = [[1,2,-2],
194 | [3,-2,4],
195 | [5,1,0.3]]
196 |
197 | source_cov = [[1,2,0.5],
198 | [2,0.6,4],
199 | [0.5,4,2]]
200 | v_with_cov / source_cov"""
201 |
202 | a = sp.linalg.solve(source_cov.T, v_with_cov.T).T
203 | return v, a, d
204 |
205 |
206 | def apply_csp_fast(epo, filt, columns=(0, -1)):
207 | """Apply the CSP filter.
208 |
209 |     Apply the spatial CSP filter to the epoched data.
210 |
211 | Parameters
212 | ----------
213 | epo : epoched ``Data`` object
214 | this method relies on the ``epo`` to have three dimensions in
215 | the following order: class, time, channel
216 | filt : 2d array
217 | the CSP filter (i.e. the ``v`` return value from
218 | :func:`calculate_csp`)
219 | columns : array of ints, optional
220 | the columns of the filter to use. The default is the first and
221 | the last one.
222 |
223 | Returns
224 | -------
225 | epo : epoched ``Data`` object
226 | The channels from the original have been replaced with the new
227 | virtual CSP channels.
228 |
229 | Examples
230 | --------
231 |
232 | >> w, a, d = calculate_csp(epo)
233 | >> epo = apply_csp_fast(epo, w)
234 |
235 | See Also
236 | --------
237 | :func:`calculate_csp`
238 | :func:`apply_csp`
239 |
240 | """
241 | # getting only selected columns
242 | f = filt[:, columns]
243 |
244 | # pre-allocating filtered vector
245 | filtered = []
246 |
247 | # TODO: pre-allocate in the right way, no append please
248 |
249 | # filtering on each trial
250 | for trial_i in range(len(epo.X)):
251 | # time x filters
252 | this_filtered = np.dot(epo.X[trial_i].T, f)
253 | # to filters x time
254 | filtered.append(this_filtered.T)
255 | return SignalAndTarget(filtered, epo.y)
256 |
257 |
258 | def apply_csp_var_log(epo, filters, columns):
259 | csp_filtered = apply_csp_fast(epo, filters, columns)
260 | # 1 is t
261 | csp_filtered.X = np.array([
262 | np.log(np.var(trial, axis=1)) for trial in csp_filtered.X
263 | ])
264 | return csp_filtered
265 |
--------------------------------------------------------------------------------
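The core of calculate_csp above is the generalized eigenvalue problem d, v = sp.linalg.eig(c2, c1 + c2). A self-contained sketch with made-up 2x2 class covariances shows the sorting convention and why the extreme filters are the discriminative ones:

    import numpy as np
    import scipy.linalg as sla

    # hypothetical class covariance matrices (symmetric positive definite)
    c1 = np.array([[2.0, 0.3],
                   [0.3, 1.0]])
    c2 = np.array([[1.0, 0.1],
                   [0.1, 3.0]])

    d, v = sla.eig(c2, c1 + c2)  # filters are the columns of v
    d = d.real
    order = np.argsort(d)[::-1]  # sort by descending eigenvalue
    d, v = d[order], v[:, order]
    # each eigenvalue lies in (0, 1): the fraction of the total variance a
    # filter assigns to class 2, so the first and last columns of v are the
    # most discriminative (the columns apply_csp_fast picks by default)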
/hgdecode/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 | import datetime
4 | import numpy as np
5 | import logging as log
6 | from os import getcwd
7 | from os import system
8 | from os import listdir
9 | from os import makedirs
10 | from os import remove
11 | from sys import platform
12 | from pickle import dump
13 | from os.path import join
14 | from os.path import exists
15 | from os.path import dirname
16 | from sklearn.metrics import confusion_matrix
17 |
18 | now_dir = ''
19 |
20 |
21 | def dash_printer(input_string='', manual_input=32):
22 | sys.stdout.write('-' * len(input_string) + '-' * manual_input + '\n')
23 |
24 |
25 | def print_manager(input_string='',
26 | print_style='normal',
27 | top_return=None,
28 | bottom_return=None):
29 | # top return
30 | if top_return is not None:
31 | sys.stdout.write('\n' * top_return)
32 |
33 | # input_string
34 | if print_style == 'normal':
35 | log.info(input_string)
36 | elif print_style == 'top-dashed':
37 | dash_printer(input_string)
38 | log.info(input_string)
39 | print_manager.last_print = input_string
40 | elif print_style == 'bottom-dashed':
41 | log.info(input_string)
42 | dash_printer(input_string)
43 | print_manager.last_print = input_string
44 | elif print_style == 'double-dashed':
45 | dash_printer(input_string)
46 | log.info(input_string)
47 | dash_printer(input_string)
48 | print_manager.last_print = input_string
49 | elif print_style == 'last':
50 | log.info(input_string)
51 | dash_printer(print_manager.last_print)
52 | print_manager.last_print = input_string
53 |
54 | # bottom return
55 | if bottom_return is not None:
56 | sys.stdout.write('\n' * bottom_return)
57 |
58 |
59 | def datetime_dir_format():
60 | now = datetime.datetime.now()
61 | second = str(now.second)
62 | minute = str(now.minute)
63 | hour = str(now.hour)
64 | day = str(now.day)
65 | month = str(now.month)
66 | year = str(now.year)
67 | if len(second) == 1:
68 | second = '0' + second
69 | if len(minute) == 1:
70 | minute = '0' + minute
71 | if len(hour) == 1:
72 | hour = '0' + hour
73 | if len(day) == 1:
74 | day = '0' + day
75 | if len(month) == 1:
76 | month = '0' + month
77 | return '_'.join(['-'.join([year, month, day]),
78 | '-'.join([hour, minute, second])])
79 |
80 |
81 | def touch_dir(directory):
82 | if exists(directory):
83 | return True
84 | else:
85 | makedirs(directory)
86 | return False
87 |
88 |
89 | def touch_file(file_path):
90 |     # despite the name, this function does not create anything:
91 |     # unlike touch_dir, it only reports whether file_path
92 |     # already exists
93 |     return exists(file_path)
94 |
95 |
96 | def listdir2(path):
97 |     names = listdir(path)  # like listdir, but skipping hidden entries
98 |     idx = 0
99 |     while idx < len(names):
100 |         if names[idx][0] == '.':
101 |             names.remove(names[idx])
102 |         else:
103 |             idx += 1
104 |     return names
105 |
106 |
107 | def create_log(results_dir,
108 | learning_type='ml',
109 | algorithm_or_model_name='FBCSP_rLDA',
110 | subject_id=1,
111 | output_on_file=False,
112 | use_last_result_directory=False):
113 | # getting now_dir from global
114 | global now_dir
115 |
116 | # setting temporary results directory
117 | results_dir = join(results_dir,
118 | learning_type,
119 | algorithm_or_model_name)
120 |
121 | # setting now_dir if necessary
122 |     if len(now_dir) == 0:
123 | if use_last_result_directory is True:
124 | dirs_in_folder = listdir(results_dir)
125 | dirs_in_folder.sort()
126 | now_dir = dirs_in_folder[-1]
127 | else:
128 | now_dir = datetime_dir_format()
129 |
130 | # setting log_file_dir
131 | log_file_dir = join(results_dir, now_dir)
132 |
133 | # setting subject_id_str
134 | if type(subject_id) is str:
135 | subject_str = subject_id
136 | else:
137 | subject_str = str(subject_id)
138 | if len(subject_str) == 1:
139 | subject_str = '0' + subject_str
140 | subject_str = 'subj' + subject_str
141 |
142 | # setting subject_results_dir
143 | subject_results_dir = join(log_file_dir, subject_str)
144 |
145 | # touching directories
146 | touch_dir(log_file_dir)
147 | touch_dir(subject_results_dir)
148 |
149 | if output_on_file is True:
150 | # setting log_file_name
151 | log_file_name = 'log.bin'
152 |
153 | # setting log_file_path
154 | log_file_path = join(log_file_dir, log_file_name)
155 |
156 | # creating the log file
157 | sys.stdout = open(log_file_path, 'w')
158 |
159 | # opening it using system commands
160 | if platform == 'linux':
161 |             system('xdg-open ' + log_file_path.replace(' ', '\\ '))
162 |         elif platform == 'darwin':  # macOS
163 |             system('open ' + log_file_path.replace(' ', '\\ '))
164 |         else:
165 |             sys.stdout.write('platform {:s} is not yet supported'.format(
166 |                 platform))
167 |
168 | # setting the logging object configuration
169 | log.basicConfig(
170 | format='%(asctime)s | %(levelname)s: %(message)s',
171 | filemode='w',
172 | stream=sys.stdout,
173 | level=log.DEBUG
174 | )
175 | else:
176 | # setting the logging object configuration
177 | log.basicConfig(
178 | format='%(asctime)s | %(levelname)s: %(message)s',
179 | level=log.DEBUG,
180 | stream=sys.stdout
181 | )
182 |
183 | # printing current cycle information
184 | print_manager(
185 | '{} with {} on subject {}'.format(
186 | learning_type.upper(),
187 | algorithm_or_model_name,
188 | subject_id
189 | ),
190 | 'double-dashed',
191 | bottom_return=1
192 | )
193 |
194 | # returning subject_results_dir, in some case it can be helpful
195 | return subject_results_dir
196 |
197 |
198 | def ml_results_saver(exp, subj_results_dir):
199 | for fold_idx in range(exp.n_folds):
200 | # computing paths and directories
201 | fold_str = str(fold_idx + 1)
202 |         if len(fold_str) != 2:
203 | fold_str = '0' + fold_str
204 | fold_str = 'fold' + fold_str
205 | fold_dir = join(subj_results_dir, fold_str)
206 | touch_dir(fold_dir)
207 | file_path = join(fold_dir, 'fold_stats.pickle')
208 |
209 | # getting accuracies
210 | train_acc = exp.multi_class.train_accuracy[fold_idx]
211 | test_acc = exp.multi_class.test_accuracy[fold_idx]
212 |
213 | # getting y_true and y_pred for this fold
214 | train_true = exp.multi_class.train_labels[fold_idx]
215 | train_pred = exp.multi_class.train_predicted_labels[fold_idx]
216 | test_true = exp.multi_class.test_labels[fold_idx]
217 | test_pred = exp.multi_class.test_predicted_labels[fold_idx]
218 |
219 | # computing confusion matrices
220 | train_conf_mtx = confusion_matrix(train_true, train_pred)
221 | test_conf_mtx = confusion_matrix(test_true, test_pred)
222 |
223 | # creating results dictionary
224 | results = {
225 | 'train': {
226 | 'acc': train_acc,
227 | 'conf_mtx': train_conf_mtx.tolist()
228 | },
229 | 'test': {
230 | 'acc': test_acc,
231 | 'conf_mtx': test_conf_mtx.tolist()
232 | }
233 | }
234 |
235 | # saving results
236 | with open(file_path, 'wb') as f:
237 | dump(results, f)
238 |
239 |
240 | def csv_manager(csv_path, line):
241 | if not exists(csv_path):
242 | with open(csv_path, 'w', newline='') as f:
243 | writer = csv.writer(f, delimiter=',')
244 | writer.writerows([line])
245 | else:
246 | with open(csv_path, 'a', newline='') as f:
247 | writer = csv.writer(f)
248 | writer.writerows([line])
249 |
250 |
251 | def get_metrics_from_conf_mtx(conf_mtx, label_names=None):
252 | # creating standard label_names if not specified
253 | if label_names is None:
254 | label_names = ['label ' + str(x) for x in range(len(conf_mtx))]
255 |
256 | # computing true/false positive/negative from confusion matrix
257 | TP = np.diag(conf_mtx)
258 | FP = conf_mtx.sum(axis=0) - TP
259 | FN = conf_mtx.sum(axis=1) - TP
260 | TN = conf_mtx.sum() - (FP + FN + TP)
261 |
262 | # parsing to float
263 | TP = TP.astype(float)
264 | TN = TN.astype(float)
265 | FP = FP.astype(float)
266 | FN = FN.astype(float)
267 |
268 | # computing true positive rate (sensitivity, hit rate, recall)
269 | TPR = TP / (TP + FN)
270 |
271 | # computing true negative rate (specificity)
272 | TNR = TN / (TN + FP)
273 |
274 | # computing positive predictive value (precision)
275 | PPV = TP / (TP + FP)
276 |
277 | # computing negative predictive value
278 | NPV = TN / (TN + FN)
279 |
280 | # computing false positive rate (fall out)
281 | FPR = FP / (FP + TN)
282 |
283 | # computing false negative rate
284 | FNR = FN / (TP + FN)
285 |
286 | # computing false discovery rate
287 | FDR = FP / (TP + FP)
288 |
289 | # computing f1-score
290 | F1 = 2 * TP / (2 * TP + FP + FN)
291 |
292 | # computing accuracy on single label
293 | ACC = (TP + TN) / (TP + FP + FN + TN)
294 |
295 | # computing overall accuracy
296 | acc = TP.sum() / conf_mtx.sum()
297 |
298 | # pre-allocating metrics_report
299 | metrics_report = {x: None for x in label_names}
300 | metrics_report['acc'] = acc
301 |
302 | # filling metrics_report
303 | for idx, label in enumerate(label_names):
304 | metrics_report[label] = {
305 | 'TP': TP[idx],
306 | 'TN': TN[idx],
307 | 'FP': FP[idx],
308 | 'FN': FN[idx],
309 | 'TPR': TPR[idx],
310 | 'TNR': TNR[idx],
311 | 'PPV': PPV[idx],
312 | 'NPV': NPV[idx],
313 | 'FPR': FPR[idx],
314 | 'FNR': FNR[idx],
315 | 'FDR': FDR[idx],
316 | 'F1': F1[idx],
317 | 'ACC': ACC[idx]
318 | }
319 |
320 | # returning metrics_report
321 | return metrics_report
322 |
323 |
324 | def check_significant_digits(num):
325 | num = float(num)
326 | if num < 0:
327 | negative_flag = True
328 | num = -num
329 | else:
330 | negative_flag = False
331 | if num < 0.01: # from 0.009999
332 | num = np.round(num, 5)
333 | elif num < 0.1: # from 0.09999
334 | num = np.round(num, 4)
335 | else:
336 | num = np.round(num, 3)
337 | num = num * 100
338 | if num == 0:
339 | num = '0'
340 | elif num < 1:
341 | num = np.round(num, 3)
342 | num = str(num)
343 | num += '0' * (5 - len(num))
344 | num = num[0:4]
345 | elif num < 10:
346 | num = np.round(num, 2)
347 | num = str(num)
348 | num += '0' * (4 - len(num))
349 | num = num[0:4]
350 | elif num == 100:
351 | num = '100'
352 | else:
353 | num = np.round(num, 1)
354 | num = str(num)
355 | if negative_flag:
356 | num = '-' + num
357 | return num
358 |
359 |
360 | def get_subj_str(subj_id):
361 | return my_formatter(subj_id, 'subj')
362 |
363 |
364 | def get_fold_str(fold_id):
365 | return my_formatter(fold_id, 'fold')
366 |
367 |
368 | def my_formatter(num, name):
369 | num_str = str(num)
370 | if len(num_str) == 1:
371 | num_str = '0' + num_str
372 | return name + num_str
373 |
374 |
375 | def get_path(results_dir=None,
376 | learning_type='dl',
377 | algorithm_or_model_name=None,
378 | epoching=(-500, 4000),
379 | fold_type='single_subject',
380 | n_folds=2,
381 | deprecated=False,
382 | balanced_folds=True):
383 | # checking results_dir
384 | if results_dir is None:
385 | results_dir = join(dirname(dirname(dirname(getcwd()))), 'results')
386 |
387 | # checking algorithm_or_model_name
388 | if algorithm_or_model_name is None:
389 | if learning_type == 'ml':
390 | algorithm_or_model_name = 'FBCSP_rLDA'
391 | elif learning_type == 'dl':
392 | algorithm_or_model_name = 'DeepConvNet'
393 | else:
394 | raise ValueError(
395 |                 'Invalid learning_type: {}'.format(learning_type)
396 | )
397 |
398 | # checking epoching
399 |     if isinstance(epoching, (tuple, list)):
400 |         epoching_str = str(epoching[0]) + '_' + str(epoching[1])
401 |     elif isinstance(epoching, str):
402 | epoching_str = epoching
403 | else:
404 | raise ValueError(
405 | 'Invalid epoching type: {}'.format(epoching.__class__)
406 | )
407 |
408 | # checking fold_type
409 | folder = ''
410 | if fold_type == 'schirrmeister':
411 | folder += '1'
412 | elif fold_type == 'single_subject':
413 | if epoching_str == '-500_4000':
414 | folder += '2'
415 | elif epoching_str == '-1000_1000':
416 | folder += '3'
417 | elif epoching_str == '-1500_500':
418 | folder += '4'
419 | elif fold_type == 'cross_subject':
420 | if epoching_str == '-500_4000':
421 | folder += '5'
422 | elif epoching_str == '-1000_1000':
423 | folder += '6'
424 | elif fold_type == 'transfer_learning':
425 | if epoching_str == '-500_4000':
426 | folder += '7'
427 | else:
428 | folder += '8'
429 | elif fold_type == 'transfer_learning_frozen':
430 | folder += '9'
431 | else:
432 | raise ValueError(
433 | 'Invalid fold_type: {}'.format(fold_type)
434 | )
435 | folder += '_' + fold_type + '_' + epoching_str
436 |
437 | # checking for deprecated / not stratified
438 | if deprecated is True:
439 | if balanced_folds is True:
440 | folder = join('0_deprecated', folder)
441 | else:
442 | folder = join('0_deprecated', '#_not_stratified', folder)
443 |
444 | # building folder path
445 | folder_path = join(results_dir,
446 | 'hgdecode',
447 | learning_type,
448 | algorithm_or_model_name,
449 | folder)
450 |
451 | if (fold_type == 'single_subject') and (epoching_str == '-500_4000'):
452 | folder_path = join(folder_path, my_formatter(n_folds, 'fold'))
453 | elif fold_type == 'transfer_learning':
454 | folder_path = join(folder_path, 'train_size_' + str(n_folds))
455 | elif fold_type == 'transfer_learning_frozen':
456 | folder_path = join(folder_path, 'frozen_' + str(n_folds),
457 | 'train_size_128')
458 |
459 | return join(folder_path, listdir2(folder_path)[0])
460 |
461 |
462 | def clear_all_models(subj_results_dir):
463 | fold_folders = listdir2(subj_results_dir)
464 | for fold_folder in fold_folders:
465 | remove(join(subj_results_dir, fold_folder, 'net_best_val_loss.h5'))
466 |
--------------------------------------------------------------------------------
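To make the vectorized bookkeeping in get_metrics_from_conf_mtx concrete, a small worked example with a made-up two-class confusion matrix (rows are true labels, columns are predictions):

    import numpy as np

    conf_mtx = np.array([[50, 10],
                         [5, 35]])
    TP = np.diag(conf_mtx)                # [50, 35]
    FP = conf_mtx.sum(axis=0) - TP        # [ 5, 10]
    FN = conf_mtx.sum(axis=1) - TP        # [10,  5]
    TN = conf_mtx.sum() - (FP + FN + TP)  # [35, 50]
    acc = TP.sum() / conf_mtx.sum()       # 0.85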
/ml_main.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from os.path import join
3 | from os.path import dirname
4 | from collections import OrderedDict
5 | from numpy.random import RandomState
6 | from hgdecode.utils import create_log
7 | from hgdecode.utils import my_formatter
8 | from hgdecode.utils import ml_results_saver
9 | from hgdecode.classes import CrossValidation
10 | from hgdecode.loaders import ml_loader
11 | from hgdecode.experiments import FBCSPrLDAExperiment
12 |
13 | """
14 | SETTING PARAMETERS
15 | ------------------
16 | In the following, set or modify all the parameters used in the
17 | subsequent computation.
18 |
19 | Parameters
20 | ----------
21 | algorithm_name : str
22 |     Name of the machine learning algorithm to be used
23 | channel_names : list
24 |     Channels to use for computation
25 | data_dir : str
26 |     Path to the directory that contains the dataset
27 | name_to_start_codes : OrderedDict
28 |     All possible class names and codes in an ordered dict format
29 | results_dir : str
30 |     Path to the directory where results will be stored
31 | subject_ids : tuple
32 |     All the subject ids in a tuple; add or remove ids here to control
33 |     which subjects the algorithm is run on
34 | """
35 | # setting ml_algorithm
36 | algorithm_name = 'FBCSP_rLDA'
37 |
38 | # setting channel_names
39 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
40 | 'CP5', 'CP1', 'CP2', 'CP6',
41 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
42 | 'CP3', 'CPz', 'CP4',
43 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
44 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
45 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
46 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
47 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
48 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
49 |
50 | # setting data_dir & results_dir
51 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
52 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
53 |
54 | # setting name_to_start_codes
55 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
56 | ('Left Hand', [2]),
57 | ('Rest', [3]),
58 | ('Feet', [4])])
59 |
60 | # setting random state
61 | random_state = RandomState(1234)
62 |
63 | # real useful hyperparameters
64 | standardize_mode = 0
65 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
66 | ival = (-500, 4000)
67 | n_folds = 6
68 |
69 | # fold stuff
70 | ival_str = str(ival[0]) + '_' + str(ival[1])
71 | fold_dir = join(data_dir,
72 | 'stratified_fold_' + ival_str,
73 | my_formatter(n_folds, 'fold'))
74 |
75 | """
76 | MAIN CYCLE
77 | ----------
78 | For each subject, a new log will be created and the specific dataset loaded;
79 | this dataset will be used to create an instance of the experiment; then the
80 | experiment will be run. You can of course change all the experiment inputs
81 | to obtain different results.
82 | """
83 | for subject_id in subject_ids:
84 | # creating a log object
85 | subj_results_dir = create_log(
86 | results_dir=results_dir,
87 | learning_type='ml',
88 | algorithm_or_model_name=algorithm_name,
89 | subject_id=subject_id,
90 | output_on_file=False
91 | )
92 |
93 | # loading dataset
94 | cnt, clean_trial_mask = ml_loader(
95 | data_dir=data_dir,
96 | name_to_start_codes=name_to_start_codes,
97 | channel_names=channel_names,
98 | subject_id=subject_id,
99 | resampling_freq=250, # Schirrmeister: 250
100 | clean_ival_ms=ival, # Schirrmeister: (0, 4000)
101 | train_test_split=True, # Schirrmeister: True
102 | clean_on_all_channels=False, # Schirrmeister: True
103 | standardize_mode=standardize_mode # Schirrmeister: 2
104 | )
105 |
106 | # creating experiment instance
107 | exp = FBCSPrLDAExperiment(
108 | # signal-related inputs
109 | cnt=cnt,
110 | clean_trial_mask=clean_trial_mask,
111 | name_to_start_codes=name_to_start_codes,
112 | random_state=random_state,
113 | name_to_stop_codes=None, # Schirrmeister: None
114 | epoch_ival_ms=ival, # Schirrmeister: (-500, 4000)
115 |
116 | # bank filter-related inputs
117 | min_freq=[0, 10], # Schirrmeister: [0, 10]
118 | max_freq=[12, 122], # Schirrmeister: [12, 122]
119 | window=[6, 8], # Schirrmeister: [6, 8]
120 | overlap=[3, 4], # Schirrmeister: [3, 4]
121 | filt_order=3, # filt_order: 3
122 |
123 | # machine learning parameters
124 | n_folds=n_folds, # Schirrmeister: ?
125 | fold_file=join(fold_dir, my_formatter(subject_id, 'subj') + '.npz'),
126 | n_top_bottom_csp_filters=5, # Schirrmeister: 5
127 | n_selected_filterbands=None, # Schirrmeister: None
128 | n_selected_features=20, # Schirrmeister: 20
129 | forward_steps=2, # Schirrmeister: 2
130 | backward_steps=1, # Schirrmeister: 1
131 | stop_when_no_improvement=False, # Schirrmeister: False
132 | shuffle=False, # Schirrmeister: False
133 | average_trial_covariance=True # Schirrmeister: True
134 | )
135 |
136 | # running the experiment
137 | exp.run()
138 |
139 | # saving results for this subject
140 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir)
141 |
142 | # computing statistics for this subject
143 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir,
144 | label_names=name_to_start_codes)
145 |
--------------------------------------------------------------------------------
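For reference, the fold_file path built above relies on my_formatter, which zero-pads single-digit ids to two characters; a quick illustration (assuming the repository is on the Python path):

    from hgdecode.utils import my_formatter

    print(my_formatter(6, 'fold'))  # -> 'fold06'
    print(my_formatter(7, 'subj'))  # -> 'subj07'
    # so subject 7 of a 6-fold run reads
    # .../stratified_fold_-500_4000/fold06/subj07.npz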
/ml_main_cross_subject.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from os.path import join
3 | from os.path import dirname
4 | from collections import OrderedDict
5 | from numpy.random import RandomState
6 | from hgdecode.utils import create_log
7 | from hgdecode.utils import ml_results_saver
8 | from hgdecode.loaders import CrossSubject
9 | from hgdecode.classes import CrossValidation
10 | from hgdecode.experiments import FBCSPrLDAExperiment
11 |
12 | """
13 | SETTING PARAMETERS
14 | Here you can set whatever parameters you want
15 | """
16 | # setting ml_algorithm
17 | algorithm_name = 'FBCSP_rLDA'
18 |
19 | # setting channel_names
20 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
21 | 'CP5', 'CP1', 'CP2', 'CP6',
22 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
23 | 'CP3', 'CPz', 'CP4',
24 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
25 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
26 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
27 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
28 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
29 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
30 |
31 | # setting data_dir & results_dir
32 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
33 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
34 |
35 | # setting name_to_start_codes
36 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
37 | ('Left Hand', [2]),
38 | ('Rest', [3]),
39 | ('Feet', [4])])
40 |
41 | # setting random state
42 | random_state = RandomState(1234)
43 |
44 | # setting subject_ids
45 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
46 |
47 | # %%
48 | """
49 | STARTING LOADING ROUTINE & COMPUTATION
50 | Here you can change some parameters in the function calls as well
51 | """
52 | # creating a log object
53 | subj_results_dir = create_log(
54 | results_dir=results_dir,
55 | learning_type='ml',
56 | algorithm_or_model_name=algorithm_name,
57 | subject_id='subj_cross',
58 | output_on_file=False
59 | )
60 |
61 | # creating a cross-subject object for cross-subject validation
62 | cross_obj = CrossSubject(data_dir=data_dir,
63 | subject_ids=subject_ids,
64 | channel_names=channel_names,
65 | name_to_start_codes=name_to_start_codes,
66 | resampling_freq=250,
67 | train_test_split=True,
68 | clean_ival_ms=(-1000, 1000),
69 | epoch_ival_ms=(-1000, 1000),
70 | clean_on_all_channels=False)
71 |
72 | """
73 | One could do a soft parsing here so that it finds the folds; then exp
74 | gets passed only what it needs (cnt all, clean all, fold all) and all
75 | the rest is thrown away...
76 | """
77 |
78 | # creating experiment instance
79 | exp = FBCSPrLDAExperiment(
80 | # signal-related inputs
81 | cnt=cross_obj.data,
82 | clean_trial_mask=cross_obj.clean_trial_mask,
83 | name_to_start_codes=name_to_start_codes,
84 | random_state=random_state,
85 | name_to_stop_codes=None, # Schirrmeister: None
86 | epoch_ival_ms=(-1000, 1000), # Schirrmeister: (-500, 4000)
87 | cross_subject_object=cross_obj,
88 |
89 | # bank filter-related inputs
90 | min_freq=[0, 10], # Schirrmeister: [0, 10]
91 | max_freq=[12, 122], # Schirrmeister: [12, 122]
92 | window=[6, 8], # Schirrmeister: [6, 8]
93 | overlap=[3, 4], # Schirrmeister: [3, 4]
94 | filt_order=3, # filt_order: 3
95 |
96 | # machine learning parameters
97 | n_folds=14, # Schirrmeister: ?
98 | n_top_bottom_csp_filters=5, # Schirrmeister: 5
99 | n_selected_filterbands=None, # Schirrmeister: None
100 | n_selected_features=20, # Schirrmeister: 20
101 | forward_steps=2, # Schirrmeister: 2
102 | backward_steps=1, # Schirrmeister: 1
103 | stop_when_no_improvement=False, # Schirrmeister: False
104 | shuffle=False, # Schirrmeister: False
105 | average_trial_covariance=True # Schirrmeister: True
106 | )
107 |
108 | # running the experiment
109 | exp.run()
110 |
111 | # saving results
112 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir)
113 |
114 | # at the very end, running cross-validation
115 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir,
116 | label_names=name_to_start_codes)
117 |
--------------------------------------------------------------------------------
/schirrmeister_main.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from os.path import join
3 | from os.path import dirname
4 | from collections import OrderedDict
5 | from numpy.random import RandomState
6 | from numpy import array
7 | from numpy import floor
8 | from numpy import setdiff1d
9 | from hgdecode.utils import create_log
10 | from hgdecode.utils import print_manager
11 | from hgdecode.loaders import ml_loader
12 | from hgdecode.classes import CrossValidation
13 | from hgdecode.experiments import DLExperiment
14 | from hgdecode.experiments import FBCSPrLDAExperiment
15 | from keras import backend as K
16 | from braindecode.datautil.trial_segment import \
17 | create_signal_target_from_raw_mne
18 | from hgdecode.utils import ml_results_saver
19 |
20 | """
21 | ONLY PARAMETER YOU CAN CHOOSE
22 | -----------------------------
23 | """
24 | # set here what type of learning you want
25 | learning_type = 'ml'
26 |
27 | """
28 | SETTING OTHER PARAMETERS (YOU SHOULD NOT MODIFY THESE)
29 | ------------------------------------------------------
30 | """
31 | # setting channel_names
32 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
33 | 'CP5', 'CP1', 'CP2', 'CP6',
34 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
35 | 'CP3', 'CPz', 'CP4',
36 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
37 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
38 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
39 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
40 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
41 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
42 |
43 | # setting data_dir & results_dir
44 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
45 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
46 |
47 | # setting name_to_start_codes
48 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
49 | ('Left Hand', [2]),
50 | ('Rest', [3]),
51 | ('Feet', [4])])
52 |
53 | # setting random_state
54 | random_state = RandomState(1234)
55 |
56 | # subject list
57 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
58 |
59 | # computing algorithm_or_model_name and standardize_mode
60 | if learning_type == 'ml':
61 | algorithm_or_model_name = 'FBCSP_rLDA'
62 | standardize_mode = 0
63 | else:
64 | algorithm_or_model_name = 'DeepConvNet'
65 | standardize_mode = 2
66 |
67 | """
68 | MAIN CYCLE
69 | ----------
70 | """
71 | for subject_id in subject_ids:
72 | # creating a log object
73 | subj_results_dir = create_log(
74 | results_dir=results_dir,
75 | learning_type=learning_type,
76 | algorithm_or_model_name=algorithm_or_model_name,
77 | subject_id=subject_id,
78 | output_on_file=False
79 | )
80 |
81 | # loading cnt signal
82 | cnt, clean_trial_mask = ml_loader(
83 | data_dir=data_dir,
84 | name_to_start_codes=name_to_start_codes,
85 | channel_names=channel_names,
86 | subject_id=subject_id,
87 | resampling_freq=250, # Schirrmeister: 250
88 | clean_ival_ms=(-500, 4000), # Schirrmeister: (0, 4000)
89 | train_test_split=True, # Schirrmeister: True
90 | clean_on_all_channels=False, # Schirrmeister: True
91 | standardize_mode=standardize_mode # Schirrmeister: 2
92 | )
93 |
94 | # splitting two algorithms
95 | if learning_type == 'ml':
96 | # creating experiment instance
97 | exp = FBCSPrLDAExperiment(
98 | # signal-related inputs
99 | cnt=cnt,
100 | clean_trial_mask=clean_trial_mask,
101 | name_to_start_codes=name_to_start_codes,
102 | random_state=random_state,
103 | name_to_stop_codes=None, # Schirrmeister: None
104 | epoch_ival_ms=(-500, 4000), # Schirrmeister: (-500, 4000)
105 |
106 | # bank filter-related inputs
107 | min_freq=[0, 10], # Schirrmeister: [0, 10]
108 | max_freq=[12, 122], # Schirrmeister: [12, 122]
109 | window=[6, 8], # Schirrmeister: [6, 8]
110 | overlap=[3, 4], # Schirrmeister: [3, 4]
111 | filt_order=3, # filt_order: 3
112 |
113 | # machine learning parameters
114 | n_folds=0, # Schirrmeister: ?
115 | n_top_bottom_csp_filters=5, # Schirrmeister: 5
116 | n_selected_filterbands=None, # Schirrmeister: None
117 | n_selected_features=20, # Schirrmeister: 20
118 | forward_steps=2, # Schirrmeister: 2
119 | backward_steps=1, # Schirrmeister: 1
120 | stop_when_no_improvement=False, # Schirrmeister: False
121 | shuffle=False, # Schirrmeister: False
122 | average_trial_covariance=True # Schirrmeister: True
123 | )
124 |
125 | # running the experiment
126 | exp.run()
127 |
128 | # saving results for this subject
129 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir)
130 |
131 | # computing statistics for this subject
132 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir,
133 | label_names=name_to_start_codes)
134 | elif learning_type == 'dl':
135 | # creating schirrmeister fold
136 | all_idxs = array(range(len(clean_trial_mask)))
137 | folds = [
138 | {
139 | 'train': all_idxs[:-160],
140 | 'test': all_idxs[-160:]
141 | }
142 | ]
143 | folds[0]['train'] = folds[0]['train'][clean_trial_mask[:-160]]
144 | folds[0]['test'] = folds[0]['test'][clean_trial_mask[-160:]]
145 |
146 | # adding validation
147 | valid_idxs = array(range(int(floor(len(clean_trial_mask) * 0.1))))
148 | folds[0]['train'] = setdiff1d(folds[0]['train'], valid_idxs)
149 | folds[0]['valid'] = valid_idxs
150 |
151 | # parsing cnt to epoched data
152 | print_manager('Epoching...')
153 | epo = create_signal_target_from_raw_mne(cnt,
154 | name_to_start_codes,
155 | (-500, 4000))
156 | print_manager('DONE!!', bottom_return=1)
157 |
158 | # # cleaning epoched signal with mask
159 | # print_manager('cleaning with mask...')
160 | # epo.X = epo.X[clean_trial_mask]
161 | # epo.y = epo.y[clean_trial_mask]
162 | # print_manager('DONE!!', 'last', bottom_return=1)
163 |
164 | # creating cv instance
165 | cv = CrossValidation(X=epo.X, y=epo.y, shuffle=False)
166 |
167 | # creating EEGDataset for current fold
168 | dataset = cv.create_dataset(fold=folds[0])
169 |
170 | # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
171 | print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
172 | K.clear_session()
173 | print_manager(print_style='last', bottom_return=1)
174 |
175 | # creating experiment instance
176 | exp = DLExperiment(
177 | # non-default inputs
178 | dataset=dataset,
179 | model_name=algorithm_or_model_name,
180 | results_dir=results_dir,
181 | subj_results_dir=subj_results_dir,
182 | name_to_start_codes=name_to_start_codes,
183 | random_state=random_state,
184 | fold_idx=0,
185 |
186 | # hyperparameters
187 | dropout_rate=0.5, # Schirrmeister: 0.5
188 | learning_rate=1 * 1e-4, # Schirrmeister: ?
189 | batch_size=32, # Schirrmeister: 512
190 | epochs=1000, # Schirrmeister: ?
191 | early_stopping=False, # Schirrmeister: ?
192 | monitor='val_acc', # Schirrmeister: ?
193 | min_delta=0.0001, # Schirrmeister: ?
194 | patience=5, # Schirrmeister: ?
195 | loss='categorical_crossentropy', # Schirrmeister: ad hoc
196 | optimizer='Adam', # Schirrmeister: Adam
197 | shuffle=True, # Schirrmeister: ?
198 | crop_sample_size=None, # Schirrmeister: 1125
199 | crop_step=None, # Schirrmeister: 1
200 |
201 | # other parameters
202 | subject_id=subject_id,
203 | data_generator=False, # Schirrmeister: True
204 | save_model_at_each_epoch=False
205 | )
206 |
207 | # training
208 | exp.train()
209 |
210 | # computing cross-validation
211 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir,
212 | label_names=name_to_start_codes)
213 |
--------------------------------------------------------------------------------
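A numeric sketch of the Schirrmeister-style split built in the 'dl' branch above, assuming a hypothetical session of 1000 trials that are all clean:

    import numpy as np

    clean_trial_mask = np.ones(1000, dtype=bool)
    all_idxs = np.arange(len(clean_trial_mask))

    fold = {'train': all_idxs[:-160],  # everything but the last 160 trials
            'test': all_idxs[-160:]}   # last 160 trials held out for testing
    valid_idxs = np.arange(int(np.floor(len(clean_trial_mask) * 0.1)))
    fold['train'] = np.setdiff1d(fold['train'], valid_idxs)
    fold['valid'] = valid_idxs
    # result: valid = trials 0..99, train = trials 100..839,
    # test = trials 840..999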
/sub_routines/dl_cross_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from hgdecode.classes import CrossValidation
4 |
5 | datetime_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \
6 | 'Università di Bologna/TesiMagistrale_DavideMiani/' \
7 | 'results/hgdecode/dl/DeepConvNet/2019-01-18_13-33-01'
8 | subj_dirs = os.listdir(datetime_dir)
9 | subj_dirs.sort()
10 | subj_dirs.remove('model_report.txt')
11 | subj_dirs.remove('statistics')
12 | subj_dirs = [os.path.join(datetime_dir, x) for x in subj_dirs]
13 |
14 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
15 | ('Left Hand', [2]),
16 | ('Rest', [3]),
17 | ('Feet', [4])])
18 |
19 | for subj_dir in subj_dirs:
20 | CrossValidation.cross_validate(subj_results_dir=subj_dir,
21 | label_names=name_to_start_codes)
22 |
--------------------------------------------------------------------------------
/sub_routines/latex_tabular_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | from csv import reader
3 | from hgdecode.utils import get_path
4 | from hgdecode.utils import check_significant_digits
5 |
6 | """
7 | SET HERE YOUR PARAMETERS
8 | """
9 | # to find file parameters
10 | results_dir = None
11 | learning_type = 'dl'
12 | algorithm_or_model_name = None
13 | epoching = '-1500_500'
14 | fold_type = 'single_subject'
15 | n_folds = 12
16 | deprecated = True
17 | balanced_fold = True
18 |
19 | # metrics parameter
20 | label = 'Feet' # Feet, LeftHand, Rest or RightHand
21 | metric_type = 'overall' # label or overall
22 | metric = 'acc'
23 |
24 | """
25 | GETTING PATHS
26 | """
27 | # getting folder path
28 | folder_path = get_path(
29 | results_dir=results_dir,
30 | learning_type=learning_type,
31 | algorithm_or_model_name=algorithm_or_model_name,
32 | epoching=epoching,
33 | fold_type=fold_type,
34 | n_folds=n_folds,
35 | deprecated=deprecated,
36 | balanced_folds=balanced_fold
37 | )
38 |
39 | # getting file_path
40 | file_path = os.path.join(folder_path, 'statistics', 'tables')
41 | if metric_type == 'overall':
42 | file_path = os.path.join(file_path, metric + '.csv')
43 | else:
44 | file_path = os.path.join(file_path, label, metric + '.csv')
45 |
46 | """
47 | COMPUTATION STARTS HERE
48 | """
49 | with open(file_path) as f:
50 | csv = list(reader(f))
51 |
52 | n_folds = len(csv[0]) - 2
53 | columns = ['&\\textbf{' + str(x + 1) + '}\n' for x in range(n_folds)]
54 |
55 | output = '\\begin{table}[H]\n\\footnotesize\n\\centering\n\\begin{tabular}' + \
56 | '{|c|' + 'c' * n_folds + '|cc|}\n\\hline\n' + \
57 |          '&\\multicolumn{' + str(n_folds) + '}{c|}{\\textbf{fold}}& ' + \
58 | '&\n\\\\\n\\textbf{subj}\n'
59 | for head in columns:
60 | output += head
61 | output += '&\\textbf{mean}\n&\\textbf{std}\n\\\\\n\\hline\\hline\n'
62 |
63 | # removing header
64 | csv = csv[1:]
65 | total_m = 0
66 | total_s = 0
67 |
68 | for idx, current_row in enumerate(csv):
69 |     if idx % 2 == 0:
70 | output += '\\rowcolor[gray]{.9}\n'
71 | else:
72 | output += '\\rowcolor[gray]{.8}\n'
73 | output += '\\textbf{' + str(idx + 1) + '}\n'
74 | for idy, current_col in enumerate(current_row):
75 | output += '&' + check_significant_digits(current_col) + '\n'
76 | output += '\\\\\n'
77 | total_m += float(current_row[-2])
78 | total_s += float(current_row[-1])
79 | total_m /= len(csv)
80 | total_s /= len(csv)
81 | total_m = check_significant_digits(total_m)
82 | total_s = check_significant_digits(total_s)
83 |
84 | caption = learning_type + ' ' + metric + ' ' + \
85 | epoching.replace('_', ',') + ' ' + str(n_folds) + ' fold'
86 | output += '\\hline\n\\multicolumn{' + str(n_folds + 1) + '}' + \
87 |           '{|r|}{\\textbf{overall mean}}\n&' + total_m + '\n&' + \
88 | total_s + '\n\\\\\n\\hline\n\\end{tabular}\n' + \
89 | '\\caption{' + caption + '}\n\\label{' + caption + '}\n' + \
90 | '\\end{table}'
91 | print(output)
92 |
--------------------------------------------------------------------------------
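For reference, check_significant_digits (defined in hgdecode/utils.py above) turns a fraction into a fixed-width percentage string; a few outputs traced from that implementation:

    from hgdecode.utils import check_significant_digits

    print(check_significant_digits(0.8532))   # -> '85.3'
    print(check_significant_digits(0.04567))  # -> '4.57'
    print(check_significant_digits(0.005))    # -> '0.50'
    print(check_significant_digits(1.0))      # -> '100'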
/sub_routines/latex_tabular_parser_cross_subj.py:
--------------------------------------------------------------------------------
1 | import os
2 | from csv import reader
3 | from hgdecode.utils import check_significant_digits
4 |
5 | results_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \
6 | 'Università di Bologna/TesiMagistrale_DavideMiani/' \
7 | 'results/hgdecode'
8 | learning_type = 'dl' # dl or ml
9 | algo_or_model_name = 'DeepConvNet' # DeepConvNet or FBCSP_rLDA
10 | datetime = '2019-01-21_11-38-41'
11 | epoch_ival_ms = '-1000, 1000' # str type
12 | tables_dir = os.path.join(results_dir,
13 | learning_type,
14 | algo_or_model_name,
15 | datetime,
16 | 'statistics',
17 | 'tables')
18 | label_names = ['RightHand', 'LeftHand', 'Rest', 'Feet']
19 |
20 | # pre-allocating csv
21 | csv = []
22 |
23 | # getting accuracy
24 | with open(os.path.join(tables_dir, 'acc.csv')) as f:
25 | temp = list(reader(f))
26 | temp = temp[1]
27 | csv.append(temp)
28 |
29 | # getting precision
30 | for label in label_names:
31 | with open(os.path.join(tables_dir, label, 'prec.csv')) as f:
32 | temp = list(reader(f))
33 | temp = temp[1]
34 | csv.append(temp)
35 |
36 | # getting f1 score
37 | for label in label_names:
38 | with open(os.path.join(tables_dir, label, 'f1.csv')) as f:
39 | temp = list(reader(f))
40 | temp = temp[1]
41 | csv.append(temp)
42 |
43 | # transposing csv
44 | csv = list(map(list, zip(*csv)))
45 |
46 | # cutting away last two rows (mean and std)
47 | csv_2 = csv[-2:]
48 | csv = csv[0:-2]
49 |
50 | output = '\\begin{table}[H]\n\\footnotesize\n\\centering\n\\begin{tabular}' + \
51 | '{|c|ccccccccc|}\n\\hline\n' + \
52 | '\\textbf{fold}\n&' + \
53 | '\\textbf{acc}\n&' + \
54 | '\\textbf{pr 1}\n&' + \
55 | '\\textbf{pr 2}\n&' + \
56 | '\\textbf{pr 3}\n&' + \
57 | '\\textbf{pr 4}\n&' + \
58 | '\\textbf{f1 1}\n&' + \
59 | '\\textbf{f1 2}\n&' + \
60 | '\\textbf{f1 3}\n&' + \
61 | '\\textbf{f1 4}\n' + \
62 | '\\\\\n\\hline\\hline\n'
63 |
64 | for idx, current_row in enumerate(csv):
65 |     if idx % 2 == 0:
66 | output += '\\rowcolor[gray]{.9}\n'
67 | else:
68 | output += '\\rowcolor[gray]{.8}\n'
69 | output += '\\textbf{' + str(idx + 1) + '}\n'
70 | for idy, current_col in enumerate(current_row):
71 | output += '&' + check_significant_digits(current_col) + '\n'
72 | output += '\\\\\n'
73 | output += '\\hline\n'
74 |
75 | for idx, current_row in enumerate(csv_2):
76 | if idx == 0:
77 | output += '\\textbf{mean}\n'
78 | else:
79 | output += '\\textbf{std}\n'
80 | for idy, current_col in enumerate(current_row):
81 | output += '&' + check_significant_digits(current_col) + '\n'
82 | output += '\\\\\n'
83 | output += '\\hline\n'
84 |
85 | caption = learning_type + ' cross-subject validation ' + epoch_ival_ms
86 | output += '\\end{tabular}\n' + \
87 | '\\caption{' + caption + '}\n\\label{' + caption + '}\n' + \
88 | '\\end{table}'
89 | print(output)
90 |
--------------------------------------------------------------------------------
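The csv = list(map(list, zip(*csv))) line in the script above transposes the table: zip(*rows) pairs up the i-th element of every row, turning rows into columns. A tiny illustration:

    rows = [['a1', 'a2', 'a3'],
            ['b1', 'b2', 'b3']]
    transposed = list(map(list, zip(*rows)))
    # -> [['a1', 'b1'], ['a2', 'b2'], ['a3', 'b3']]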
/sub_routines/latex_tabular_parser_transfer_learning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from csv import reader
4 | from scipy.stats import ttest_rel
5 | from hgdecode.utils import get_path
6 | from hgdecode.utils import check_significant_digits
7 |
8 | """
9 | SET HERE YOUR PARAMETERS
10 | """
11 | ival = (-1000, 1000)
12 | train_trials_list = [4, 8, 16, 32, 64, 128]
13 | reference = 1  # 0 for ML cross, 1 for DL cross, 2 for TL4, etc.
14 | p_flag = False  # if True, the p value will be printed too
15 |
16 | """
17 | GETTING PATHS
18 | """
19 | folder_paths = [
20 | get_path(
21 | results_dir=None,
22 | learning_type=x,
23 | algorithm_or_model_name=y,
24 | epoching=ival,
25 | fold_type='cross_subject',
26 | n_folds=None,
27 | deprecated=False
28 | )
29 | for x, y in zip(['ml', 'dl'], ['FBCSP_rLDA', 'DeepConvNet'])
30 | ]
31 |
32 | folder_paths += [
33 | get_path(
34 | results_dir=None,
35 | learning_type='dl',
36 | algorithm_or_model_name='DeepConvNet',
37 | epoching=ival,
38 | fold_type='transfer_learning',
39 | n_folds=x,
40 | deprecated=False
41 | )
42 | for x in train_trials_list
43 | ]
44 |
45 | # getting file_path
46 | csv_paths = [
47 | os.path.join(x, 'statistics', 'tables', 'acc.csv')
48 | for x in folder_paths
49 | ]
50 |
51 | """
52 | GETTING DATA FROM CSV FILES
53 | """
54 | subj_data = []
55 | mean_data = []
56 | stdd_data = []
57 | for idx, csv_path in enumerate(csv_paths):
58 | with open(csv_path) as f:
59 | csv = list(reader(f))
60 | csv = csv[1:]
61 | csv = [
62 | list(map(float, csv[x]))
63 | for x in range(len(csv))
64 | ]
65 | if idx < 2:
66 | subj_data.append(csv[0][:-2])
67 | mean_data.append(csv[0][-2])
68 | stdd_data.append(csv[0][-1])
69 | else:
70 | temp_data = []
71 | for csv_line in csv:
72 | temp_data.append(csv_line[-2])
73 | subj_data.append(temp_data)
74 | mean_data.append(np.mean(temp_data))
75 | stdd_data.append(np.std(temp_data))
76 |
77 | """
78 | COMPUTING PERC AND PVAL
79 | """
80 | perc_data = []
81 | pval_data = []
82 | for idx in range(len(mean_data)):
83 |     if idx == reference:
84 | perc_data.append(0.)
85 | pval_data.append(float('nan'))
86 | else:
87 | perc_data.append((mean_data[idx] - mean_data[reference]) /
88 | mean_data[reference])
89 | pval_data.append(ttest_rel(subj_data[idx], subj_data[reference])[1])
90 |
91 | """
92 | GENERAL FORMATTING
93 | """
94 | n_subjs = len(subj_data[0])
95 | columns = [['subj'] + list(map(str, range(1, n_subjs + 1))) +
96 | ['mean', 'std', '$\\Delta_{\\textbf{\\%}}$', '$p$']]
97 | header = ['ML', 'DL', '4', '8', '16', '32', '64', '128']
98 | for idx, head in enumerate(header):
99 | temp = [check_significant_digits(str(subj_data[idx][x]))
100 | for x in range(n_subjs)]
101 | columns.append(
102 | [head] +
103 | temp +
104 | [str(check_significant_digits(mean_data[idx]))] +
105 | [str(check_significant_digits(stdd_data[idx]))] +
106 | [str(check_significant_digits(perc_data[idx]))] +
107 | [str(check_significant_digits(pval_data[idx]))]
108 | )
109 | rows = list(map(list, zip(*columns)))
110 | if p_flag is False:
111 | rows.pop()
112 |
113 | """
114 | CREATING LATEX TABULAR CODE
115 | """
116 | # pre-allocating output
117 | output = ''
118 |
119 | # opening table
120 | output += '\\begin{table}[H]\n\\footnotesize\n\\centering\n'
121 | output += '\\begin{tabular}{|c|M{1.4cm}M{1.4cm}|'
122 | output += 'c' * len(train_trials_list) + '|}\n'
123 | output += '\\hline\n&\\multicolumn{2}{c|}{\\textbf{cross-subject}}\n'
124 | output += '&\\multicolumn{' + str(len(train_trials_list)) + '}{c|}'
125 | output += '{\\textbf{transfer learning}}\n\\\\\n'
126 |
127 | # first row is an header
128 | for idx, col in enumerate(rows[0]):
129 | if idx == 0:
130 | output += '\\textbf{' + col + '}\n'
131 | else:
132 | output += '&\\textbf{' + col + '}\n'
133 | output += '\\\\\n\\hline\n\\hline\n'
134 |
135 | # creating iterator and jumping the first element (header)
136 | iterator = iter(rows)
137 | next(iterator)
138 |
139 | for idx, row in enumerate(iterator):
140 | if idx % 2 == 0:
141 | output += '\\rowcolor[gray]{.9}\n'
142 | else:
143 | output += '\\rowcolor[gray]{.8}\n'
144 | for idy, col in enumerate(row):
145 | if idy == 0:
146 | output += '\\textbf{' + col + '}\n'
147 | else:
148 | output += '&' + col + '\n'
149 | output += '\\\\\n'
150 | if idx == n_subjs - 1:
151 | output += '\\hline\n\\hline\n'
152 |
153 | output += '\\hline\n\\end{tabular}'
154 | output += '\n\\caption{tl table}\n\\label{tl table}\n'
155 | output += '\\end{table}'
156 | print(output)
157 |
--------------------------------------------------------------------------------
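A minimal sketch of the statistics computed above, with made-up per-subject accuracies: ttest_rel performs a paired t-test (same subjects under two configurations), and perc_data is the relative change of the means:

    import numpy as np
    from scipy.stats import ttest_rel

    ref = np.array([0.85, 0.86, 0.90, 0.87, 0.92])  # e.g. DL cross-subject
    alt = np.array([0.91, 0.88, 0.95, 0.90, 0.93])  # e.g. transfer learning
    t_stat, p_value = ttest_rel(alt, ref)
    rel_gain = (alt.mean() - ref.mean()) / ref.mean()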
/sub_routines/latex_tabular_parser_transfer_learning_frozen_layers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from csv import reader
4 | from scipy.stats import ttest_rel
5 | from hgdecode.utils import get_path
6 | from hgdecode.utils import check_significant_digits
7 |
8 | """
9 | SET HERE YOUR PARAMETERS
10 | """
11 | ival = (-500, 4000)
12 | frozen_layers_list = [1, 2, 3, 4, 5, 6, -5, -4, -3, -2, -1]
13 | reference = 0 # 0 for ML cross, 1 for DL cross, 2 for TL4 ecc.
14 | p_flag = False # if true, it will print p value too.
15 |
16 | """
17 | GETTING PATHS
18 | """
19 | folder_paths = [get_path(
20 | results_dir=None,
21 | learning_type='dl',
22 | algorithm_or_model_name=None,
23 | epoching=ival,
24 | fold_type='cross_subject',
25 | n_folds=None,
26 | deprecated=False
27 | )]
28 |
29 | folder_paths += [get_path(
30 | results_dir=None,
31 | learning_type='dl',
32 | algorithm_or_model_name=None,
33 | epoching=ival,
34 | fold_type='transfer_learning',
35 | n_folds=128,
36 | deprecated=False
37 | )]
38 |
39 | folder_paths += [
40 | get_path(
41 | results_dir=None,
42 | learning_type='dl',
43 | algorithm_or_model_name='DeepConvNet',
44 | epoching=ival,
45 | fold_type='transfer_learning_frozen',
46 | n_folds=x,
47 | deprecated=False
48 | )
49 | for x in frozen_layers_list
50 | ]
51 |
52 | # getting file_path
53 | csv_paths = [
54 | os.path.join(x, 'statistics', 'tables', 'acc.csv')
55 | for x in folder_paths
56 | ]
57 |
58 | """
59 | GETTING DATA FROM CSV FILES
60 | """
61 | subj_data = []
62 | mean_data = []
63 | stdd_data = []
64 | for idx, csv_path in enumerate(csv_paths):
65 | with open(csv_path) as f:
66 | csv = list(reader(f))
67 | csv = csv[1:]
68 | csv = [
69 | list(map(float, csv[x]))
70 | for x in range(len(csv))
71 | ]
72 | if idx < 1:
73 | subj_data.append(csv[0][:-2])
74 | mean_data.append(csv[0][-2])
75 | stdd_data.append(csv[0][-1])
76 | else:
77 | temp_data = []
78 | for csv_line in csv:
79 | temp_data.append(csv_line[-2])
80 | subj_data.append(temp_data)
81 | mean_data.append(np.mean(temp_data))
82 | stdd_data.append(np.std(temp_data))
83 |
84 | """
85 | COMPUTING PERC AND PVAL
86 | """
87 | perc_data = []
88 | pval_data = []
89 | for idx in range(len(mean_data)):
90 |     if idx == reference:
91 | perc_data.append(0.)
92 | pval_data.append(float('nan'))
93 | else:
94 | perc_data.append((mean_data[idx] - mean_data[reference]) /
95 | mean_data[reference])
96 | pval_data.append(ttest_rel(subj_data[idx], subj_data[reference])[1])
97 |
98 | """
99 | GENERAL FORMATTING
100 | """
101 | n_subjs = len(subj_data[0])
102 | columns = [['subj'] + list(map(str, range(1, n_subjs + 1))) +
103 | ['mean', 'std', '$\\Delta_{\\textbf{\\%}}$', '$p$']]
104 | header = ['CL', '0'] + list(map(str, frozen_layers_list))
105 | for idx, head in enumerate(header):
106 | temp = [check_significant_digits(str(subj_data[idx][x]))
107 | for x in range(n_subjs)]
108 | columns.append(
109 | [head] +
110 | temp +
111 | [str(check_significant_digits(mean_data[idx]))] +
112 | [str(check_significant_digits(stdd_data[idx]))] +
113 | [str(check_significant_digits(perc_data[idx]))] +
114 | [str(check_significant_digits(pval_data[idx]))]
115 | )
116 | rows = list(map(list, zip(*columns)))
117 | if p_flag is False:
118 | rows.pop()
119 |
120 | """
121 | CREATING LATEX TABULAR CODE
122 | """
123 | # pre-allocating output
124 | output = ''
125 |
126 | # opening table
127 | output += '\\begin{table}[H]\n\\footnotesize\n\\centering\n'
128 | output += '\\begin{tabular}{|c|cc|'
129 | output += 'c' * len(frozen_layers_list) + '|}\n'
130 | output += '\\hline\n&&\n'
131 | output += '&\\multicolumn{' + str(len(frozen_layers_list)) + '}{c|}'
132 | output += '{\\textbf{transfer learning with frozen layers}}\n\\\\\n'
133 |
134 | # first row is an header
135 | for idx, col in enumerate(rows[0]):
136 | if idx == 0:
137 | output += '\\textbf{' + col + '}\n'
138 | else:
139 | output += '&\\textbf{' + col + '}\n'
140 | output += '\\\\\n\\hline\n\\hline\n'
141 |
142 | # creating iterator and jumping the first element (header)
143 | iterator = iter(rows)
144 | next(iterator)
145 |
146 | for idx, row in enumerate(iterator):
147 | if idx % 2 == 0:
148 | output += '\\rowcolor[gray]{.9}\n'
149 | else:
150 | output += '\\rowcolor[gray]{.8}\n'
151 | for idy, col in enumerate(row):
152 | if idy == 0:
153 | output += '\\textbf{' + col + '}\n'
154 | else:
155 | output += '&' + col + '\n'
156 | output += '\\\\\n'
157 | if idx == n_subjs - 1:
158 | output += '\\hline\n\\hline\n'
159 |
160 | output += '\\hline\n\\end{tabular}'
161 | output += '\n\\caption{tl table}\n\\label{tl table}\n'
162 | output += '\\end{table}'
163 | print(output)
164 |
--------------------------------------------------------------------------------
/sub_routines/learning_curve.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from pylab import savefig
4 | from pickle import load
5 | from os.path import join
6 | from os.path import dirname
7 | from hgdecode.utils import touch_dir
8 | from hgdecode.utils import get_subj_str
9 | from hgdecode.utils import get_fold_str
10 | from hgdecode.utils import get_path
11 |
12 | """
13 | SET HERE YOUR PARAMETERS
14 | """
15 | # to find file parameters
16 | results_dir = None
17 | learning_type = 'dl'
18 | algorithm_or_model_name = None
19 | epoching = '-500_4000'
20 | fold_type = 'single_subject'
21 | n_folds_list = [2, 4, 6, 8, 10, 12]  # must be a list of integers
22 | deprecated = False
23 |
24 | fontsize_1 = 35
25 | fontsize_2 = 27.5
26 | fig_size = (22, 7.5)
27 |
28 | """
29 | GETTING PATHS
30 | """
31 | # trainings stuff
32 | folder_paths = [
33 | get_path(
34 | results_dir=results_dir,
35 | learning_type=learning_type,
36 | algorithm_or_model_name=algorithm_or_model_name,
37 | epoching=epoching,
38 | fold_type=fold_type,
39 | n_folds=x,
40 | deprecated=deprecated
41 | )
42 | for x in n_folds_list
43 | ]
44 | n_trainings = len(folder_paths)
45 |
46 | # saving stuff
47 | savings_dir = join(dirname(dirname(folder_paths[0])), 'learning_curve')
48 | touch_dir(savings_dir)
49 |
50 | """
51 | SUBJECTS STUFF
52 | """
53 | # subject stuff
54 | n_trials_list = [480, 973, 1040, 1057, 880, 1040, 1040, 814, 1040, 1040,
55 | 1040, 1040, 950, 1040]
56 | n_subjects = len(n_trials_list)
57 | subj_str_list = [get_subj_str(x + 1) for x in range(n_subjects)]
58 |
59 | """
60 | COMPUTATION STARTS HERE
61 | """
62 | # pre-allocating results dictionary
63 | results = {
64 | 'n_folds': [],
65 | 'n_trials': [],
66 | 'n_train_trials': [],
67 | 'n_valid_trials': [],
68 | 'n_test_trials': [],
69 | 'perc_train_trials': [],
70 | 'perc_valid_trials': [],
71 | 'perc_test_trials': [],
72 | 'm_acc': [],
73 | 's_acc': []
74 | }
75 |
76 | # cycling on subjects
77 | for subj, current_n_trials in zip(subj_str_list, n_trials_list):
78 | n_folds = []
79 | n_trials = []
80 | n_train_trials = []
81 | n_valid_trials = []
82 | n_test_trials = []
83 | perc_train_trials = []
84 | perc_valid_trials = []
85 | perc_test_trials = []
86 | m_acc = []
87 | s_acc = []
88 |
89 | # cycling on all possible fold splits
90 | for idx, current_n_folds in enumerate(n_folds_list):
91 | n_folds.append(current_n_folds)
92 | n_trials.append(current_n_trials)
93 | if learning_type == 'dl':
94 | n_valid_trials.append(int(np.floor(n_trials[idx] * 0.1)))
95 | else:
96 | n_valid_trials.append(0)
97 | n_train_trials.append(
98 | int(np.ceil(n_trials[idx] / n_folds[idx]) * (n_folds[idx] - 1)) -
99 | n_valid_trials[idx])
100 | n_test_trials.append(n_trials[idx] - n_train_trials[idx] -
101 | n_valid_trials[idx])
102 | perc_train_trials.append(
103 | np.round(n_train_trials[idx] / n_trials[idx] * 100, 1))
104 | perc_valid_trials.append(
105 | np.round(n_valid_trials[idx] / n_trials[idx] * 100, 1))
106 | perc_test_trials.append(
107 | np.round(100 - perc_valid_trials[idx] - perc_train_trials[idx], 1))
108 |
109 | # cycling on folds
110 | folds_acc = []
111 | for fold_str in [get_fold_str(x + 1) for x in range(n_folds[idx])]:
112 | file_path = join(folder_paths[idx],
113 | subj, fold_str, 'fold_stats.pickle')
114 | with open(file_path, 'rb') as f:
115 | fold_stats = load(f)
116 | folds_acc.append(fold_stats['test']['acc'])
117 | m_acc.append(np.mean(folds_acc) * 100)
118 | s_acc.append(np.std(folds_acc) * 100)
119 |
120 | # assigning results for this subject
121 | results['n_folds'].append(n_folds)
122 | results['n_trials'].append(n_trials)
123 | results['n_train_trials'].append(n_train_trials)
124 | results['n_valid_trials'].append(n_valid_trials)
125 | results['n_test_trials'].append(n_test_trials)
126 | results['perc_train_trials'].append(perc_train_trials)
127 | results['perc_valid_trials'].append(perc_valid_trials)
128 | results['perc_test_trials'].append(perc_test_trials)
129 | results['m_acc'].append(m_acc)
130 | results['s_acc'].append(s_acc)
131 |
132 | # plotting learning curve for this subject
133 | m_acc = np.array(m_acc)
134 | s_acc = np.array(s_acc)
135 | plot_path = join(savings_dir, subj)
136 | if learning_type == 'dl':
137 | title = '{} learning curve\n'.format(subj) + \
138 | '({} samples, {} validation samples)'.format(
139 | n_trials[0], n_valid_trials[0])
140 | else:
141 | title = '{} learning curve\n({} samples)'.format(subj, n_trials[0])
142 | x_tick_labels = ['{}\n({} folds)'.format(trials, folds)
143 | for trials, folds in zip(n_train_trials, n_folds)]
144 | plt.figure(dpi=100, figsize=(12.8, 7.2), facecolor='w', edgecolor='k')
145 | plt.style.use('seaborn-whitegrid')
146 | plt.errorbar(x=n_folds, y=m_acc, yerr=s_acc,
147 | fmt='-.o', color='b', ecolor='r',
148 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
149 | plt.xlabel('training samples', fontsize=25)
150 | plt.ylabel('accuracy (%)', fontsize=25)
151 | plt.xticks(n_folds, labels=x_tick_labels, fontsize=20)
152 | plt.yticks(fontsize=20)
153 | plt.title(title, fontsize=25)
154 |
155 | # saving figure
156 | savefig(plot_path, bbox_inches='tight')
157 |
158 | # getting data for last plot
159 | n_trials = np.array(results['n_trials'])
160 | n_train_trials = np.array(results['n_train_trials'])
161 | n_valid_trials = np.array(results['n_valid_trials'])
162 |
163 | # averaging data
164 | m_n_trials = int(np.round(np.mean(n_trials, axis=0))[0].tolist())
165 | s_n_trials = int(np.round(np.std(n_trials, axis=0))[0].tolist())
166 | m_n_train_trials = np.round(np.mean(n_train_trials, axis=0)).tolist()
167 | s_n_train_trials = np.round(np.std(n_train_trials, axis=0)).tolist()
168 | m_n_train_trials = list(map(int, m_n_train_trials))
169 | s_n_train_trials = list(map(int, s_n_train_trials))
170 | m_n_valid_trials = np.round(np.mean(n_valid_trials, axis=0)).tolist()
171 | s_n_valid_trials = np.round(np.std(n_valid_trials, axis=0)).tolist()
172 | m_n_valid_trials = list(map(int, m_n_valid_trials))
173 | s_n_valid_trials = list(map(int, s_n_valid_trials))
174 | m_acc = np.mean(np.array(results['m_acc']), axis=0)
175 | s_acc = np.mean(np.array(results['s_acc']), axis=0)
176 |
177 | # plotting learning curve for total mean
178 | plot_path = join(savings_dir, learning_type + '_learning_curve')
179 | if learning_type == 'dl':
180 | # title = '{} learning curve\n({}$\pm${} samples, {}$\pm${} validation ' \
181 | # 'samples)'.format('average', m_n_trials, s_n_trials,
182 | # m_n_valid_trials[0], s_n_valid_trials[0])
183 | title = 'totale esempi per soggetto: {}$\\pm${}; esempi di validazione: ' \
184 | '{}$\\pm${}'.format(m_n_trials, s_n_trials,
185 | m_n_valid_trials[0], s_n_valid_trials[0])
186 | else:
187 | # title = '{} learning curve\n({}$\pm${} samples)'.format(
188 | # 'average', m_n_trials, s_n_trials)
189 | title = 'totale esempi per soggetto: {}$\\pm${}'.format(
190 | m_n_trials, s_n_trials)
191 | x_tick_labels = ['{}$\\pm${}\n({} fold)'.format(
192 | m_n_train_trials[idx], s_n_train_trials[idx], n_folds_list[idx])
193 | for idx in range(n_trainings)]
194 | plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
195 | plt.style.use('seaborn-whitegrid')
196 | plt.errorbar(x=n_folds_list, y=m_acc, yerr=s_acc,
197 | fmt='-.o', color='b', ecolor='r',
198 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
199 | plt.xlabel('esempi di training', fontsize=fontsize_1)
200 | plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
201 | plt.xticks(n_folds_list, labels=x_tick_labels, fontsize=fontsize_2)
202 | plt.yticks(fontsize=fontsize_2)
203 | plt.title(title, fontsize=fontsize_1)
204 |
205 | # in case of single_subject with (-500, 4000) epoching, fixing the y limits
206 | if fold_type == 'single_subject':
207 | if epoching == '-500_4000':
208 | plt.ylim(80, 100)
209 |
210 | # saving figure
211 | savefig(plot_path, bbox_inches='tight')
212 |
--------------------------------------------------------------------------------
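The split arithmetic in the inner loop above can be checked by hand: with k folds, training gets k-1 folds minus the validation share, and the test fold receives the remainder. A worked example using one of the trial counts from n_trials_list (1040 trials, 4 folds):

    import numpy as np

    n_trials, n_folds = 1040, 4
    n_valid = int(np.floor(n_trials * 0.1))  # 10% held out for validation -> 104
    n_train = int(np.ceil(n_trials / n_folds) * (n_folds - 1)) - n_valid  # -> 676
    n_test = n_trials - n_train - n_valid  # remainder -> 260
    assert (n_train, n_valid, n_test) == (676, 104, 260)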
/sub_routines/ml_cross_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from hgdecode.classes import CrossValidation
4 |
5 | datetime_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \
6 | 'Università di Bologna/TesiMagistrale_DavideMiani/' \
7 | 'results/hgdecode/ml/FBCSP_rLDA/2019-01-15_01-54-25'
8 | subj_dirs = os.listdir(datetime_dir)
9 | subj_dirs.sort()
10 | subj_dirs.remove('statistics')
11 | subj_dirs = [os.path.join(datetime_dir, x) for x in subj_dirs]
12 |
13 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
14 | ('Left Hand', [2]),
15 | ('Rest', [3]),
16 | ('Feet', [4])])
17 |
18 | for subj_dir in subj_dirs:
19 | CrossValidation.cross_validate(subj_results_dir=subj_dir,
20 | label_names=name_to_start_codes)
21 |
--------------------------------------------------------------------------------
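The loop above assumes each datetime directory contains one folder per subject plus a 'statistics' folder, which is removed before iterating. A minimal sketch of that filtering over a hypothetical listing:

    # hypothetical directory listing; the real one comes from os.listdir
    listing = ['statistics', 'subj_02', 'subj_01']
    subj_dirs = sorted(d for d in listing if d != 'statistics')
    print(subj_dirs)  # ['subj_01', 'subj_02']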
/sub_routines/t_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from csv import reader
3 | from numpy import array
4 | from scipy.stats import ttest_rel
5 | from hgdecode.utils import get_path
6 |
7 | """
8 | TRAINING 1
9 | """
10 | results_dir = None
11 | learning_type = 'dl'
12 | algorithm_or_model_name = None
13 | epoching = '-1000_1000'
14 | fold_type_1 = 'single_subject'
15 | n_folds_list = [12] # must be a list of integers
16 | deprecated = False
17 | balanced_folds = True
18 | folder_paths_1 = [
19 | get_path(
20 | results_dir=results_dir,
21 | learning_type=learning_type,
22 | algorithm_or_model_name=algorithm_or_model_name,
23 | epoching=epoching,
24 | fold_type=fold_type_1,
25 | n_folds=x,
26 | deprecated=deprecated,
27 | balanced_folds=balanced_folds
28 | )
29 | for x in n_folds_list
30 | ]
31 |
32 | """
33 | TRAINING 2
34 | """
35 | results_dir = None
36 | learning_type = 'ml'
37 | algorithm_or_model_name = None
38 | epoching = '-500_4000'
39 | fold_type_2 = 'single_subject'
40 | n_folds_list = [12] # must be a list of integers
41 | deprecated = False
42 | balanced_folds = True
43 | folder_paths_2 = [
44 | get_path(
45 | results_dir=results_dir,
46 | learning_type=learning_type,
47 | algorithm_or_model_name=algorithm_or_model_name,
48 | epoching=epoching,
49 | fold_type=fold_type_2,
50 | n_folds=x,
51 | deprecated=deprecated,
52 | balanced_folds=balanced_folds
53 | )
54 | for x in n_folds_list
55 | ]
56 |
57 | """
58 | T-TESTING
59 | """
60 | for training_1, training_2 in zip(folder_paths_1, folder_paths_2):
61 | # loading training_1 accuracies
62 | training_1_acc_csv_path = os.path.join(training_1,
63 | 'statistics', 'tables', 'acc.csv')
64 | with open(training_1_acc_csv_path) as f:
65 | training_1_csv = list(reader(f))
66 | if fold_type_1 == 'cross_subject':
67 | training_1_accs = array(list(map(float, training_1_csv[1][:-2])))
68 | else:
69 | training_1_accs = array([
70 | float(training_1_csv[x][-2]) for x in range(1, len(training_1_csv))
71 | ])
72 |
73 | # loading training_2 accuracies
74 | training_2_acc_csv_path = os.path.join(training_2,
75 | 'statistics', 'tables', 'acc.csv')
76 | with open(training_2_acc_csv_path) as f:
77 | training_2_csv = list(reader(f))
78 | if fold_type_2 == 'cross_subject':
79 | training_2_accs = array(list(map(float, training_2_csv[1][:-2])))
80 | else:
81 | training_2_accs = array([
82 | float(training_2_csv[x][-2]) for x in range(1, len(training_2_csv))
83 | ])
84 |
85 | # running t-test
86 | statistic, p_value = ttest_rel(training_1_accs, training_2_accs)
87 | print('t = {}, p = {}'.format(statistic, p_value))
88 |
--------------------------------------------------------------------------------
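ttest_rel performs a paired t-test, pairing the two trainings' accuracies subject by subject, so both acc.csv files must list subjects in the same order. A minimal sketch with made-up accuracies (illustrative only, not real results):

    from numpy import array
    from scipy.stats import ttest_rel

    # hypothetical per-subject accuracies for two trainings
    acc_1 = array([0.92, 0.88, 0.95, 0.90, 0.86])
    acc_2 = array([0.89, 0.85, 0.93, 0.91, 0.84])
    statistic, p_value = ttest_rel(acc_1, acc_2)
    print('t = {:.3f}, p = {:.3f}'.format(statistic, p_value))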
/sub_routines/transfer_learning_curve.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from pylab import savefig
4 | from pickle import load
5 | from os.path import join
6 | from os.path import dirname
7 | from hgdecode.utils import listdir2
8 | from hgdecode.utils import touch_dir
9 | from hgdecode.utils import get_subj_str
10 | from hgdecode.utils import get_fold_str
11 | from hgdecode.utils import get_path
12 |
13 |
14 | def get_fold_number(folder_path, subj_str):
15 | return len(listdir2(join(folder_path, subj_str)))
16 |
17 |
18 | """
19 | SET YOUR PARAMETERS HERE
20 | """
21 | # parameters for locating the results files
22 | results_dir = None
23 | learning_type = 'dl'
24 | algorithm_or_model_name = None
25 | epoching = '-1000_1000'
26 | fold_type = 'transfer_learning'
27 | train_size_list = [4, 8, 16, 32, 64, 128] # must be a list of integers
28 | deprecated = False
29 |
30 | fontsize_1 = 35
31 | fontsize_2 = 27.5
32 | fig_size = (22, 7.5)
33 |
34 | """
35 | GETTING PATHS
36 | """
37 | # trainings stuff
38 | folder_paths = [
39 | get_path(
40 | results_dir=results_dir,
41 | learning_type=learning_type,
42 | algorithm_or_model_name=algorithm_or_model_name,
43 | epoching=epoching,
44 | fold_type=fold_type,
45 | n_folds=x,
46 | deprecated=deprecated
47 | )
48 | for x in train_size_list
49 | ]
50 | n_trainings = len(folder_paths)
51 |
52 | # saving stuff
53 | savings_dir = join(dirname(dirname(folder_paths[0])), 'learning_curve')
54 | touch_dir(savings_dir)
55 |
56 | """
57 | SUBJECTS STUFF
58 | """
59 | # subject stuff
60 | n_trials_list = [480, 973, 1040, 1057, 880, 1040, 1040, 814, 1040, 1040,
61 | 1040, 1040, 950, 1040]
62 | n_subjects = len(n_trials_list)
63 | subj_str_list = [get_subj_str(x + 1) for x in range(n_subjects)]
64 |
65 | """
66 | COMPUTATION STARTS HERE
67 | """
68 | # pre-allocating results dictionary
69 | results = {
70 | 'n_folds': [],
71 | 'n_trials': [],
72 | 'n_train_trials': [],
73 | 'n_valid_trials': [],
74 | 'n_test_trials': [],
75 | 'perc_train_trials': [],
76 | 'perc_valid_trials': [],
77 | 'perc_test_trials': [],
78 | 'm_acc': [],
79 | 's_acc': []
80 | }
81 |
82 | # cycling on subjects
83 | for subj, current_n_trials in zip(subj_str_list, n_trials_list):
84 | n_folds = []
85 | n_trials = []
86 | n_train_trials = []
87 | n_valid_trials = []
88 | n_test_trials = []
89 | perc_train_trials = []
90 | perc_valid_trials = []
91 | perc_test_trials = []
92 | m_acc = []
93 | s_acc = []
94 |
95 | # cycling on all possible fold splits
96 | for idx, current_train_size in enumerate(train_size_list):
97 | n_folds.append(get_fold_number(folder_paths[idx], subj))
98 | n_trials.append(current_n_trials)
99 | n_valid_trials.append(int(np.floor(n_trials[idx] * 0.1)))
100 | n_train_trials.append(current_train_size)
101 | n_test_trials.append(n_trials[idx] - n_train_trials[idx] -
102 | n_valid_trials[idx])
103 | perc_train_trials.append(
104 | np.round(n_train_trials[idx] / n_trials[idx] * 100, 1))
105 | perc_valid_trials.append(
106 | np.round(n_valid_trials[idx] / n_trials[idx] * 100, 1))
107 | perc_test_trials.append(
108 | np.round(100 - perc_valid_trials[idx] - perc_train_trials[idx], 1))
109 |
110 | # cycling on folds
111 | folds_acc = []
112 | for fold_str in [get_fold_str(x + 1) for x in range(n_folds[idx])]:
113 | file_path = join(folder_paths[idx],
114 | subj, fold_str, 'fold_stats.pickle')
115 | with open(file_path, 'rb') as f:
116 | fold_stats = load(f)
117 | folds_acc.append(fold_stats['test']['acc'])
118 | m_acc.append(np.mean(folds_acc) * 100)
119 | s_acc.append(np.std(folds_acc) * 100)
120 |
121 | # assigning results for this subject
122 | results['n_folds'].append(n_folds)
123 | results['n_trials'].append(n_trials)
124 | results['n_train_trials'].append(n_train_trials)
125 | results['n_valid_trials'].append(n_valid_trials)
126 | results['n_test_trials'].append(n_test_trials)
127 | results['perc_train_trials'].append(perc_train_trials)
128 | results['perc_valid_trials'].append(perc_valid_trials)
129 | results['perc_test_trials'].append(perc_test_trials)
130 | results['m_acc'].append(m_acc)
131 | results['s_acc'].append(s_acc)
132 |
133 | # plotting learning curve for this subject
134 | m_acc = np.array(m_acc)
135 | s_acc = np.array(s_acc)
136 | plot_path = join(savings_dir, subj)
137 | if learning_type == 'dl':
138 | title = '{} transfer learning curve\n'.format(subj) + \
139 | '({} samples, {} validation samples)'.format(
140 | n_trials[0], n_valid_trials[0])
141 | else:
142 | title = '{} transfer learning curve\n({} samples)'.format(subj,
143 | n_trials[0])
144 | x_tick_labels = ['{}\n({} folds)'.format(trials, folds)
145 | for trials, folds in zip(n_train_trials, n_folds)]
146 | plt.figure(dpi=100, figsize=(12.8, 7.2), facecolor='w', edgecolor='k')
147 | plt.style.use('seaborn-whitegrid')
148 | plt.errorbar(x=[2, 4, 6, 8, 10, 12], y=m_acc, yerr=s_acc,
149 | fmt='-.o', color='b', ecolor='r',
150 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
151 | plt.xlabel('training samples', fontsize=25)
152 | plt.ylabel('accuracy (%)', fontsize=25)
153 | plt.xticks([2, 4, 6, 8, 10, 12], labels=x_tick_labels, fontsize=20)
154 | plt.yticks(fontsize=20)
155 | plt.title(title, fontsize=25)
156 |
157 | # saving figure
158 | savefig(plot_path, bbox_inches='tight')
159 |
160 | # getting data for last plot
161 | n_trials = np.array(results['n_trials'])
162 | n_train_trials = np.array(results['n_train_trials'])
163 | n_valid_trials = np.array(results['n_valid_trials'])
164 | n_folds = np.array(results['n_folds'])
165 |
166 | # averaging data
167 | m_n_trials = int(np.round(np.mean(n_trials, axis=0))[0].tolist())
168 | s_n_trials = int(np.round(np.std(n_trials, axis=0))[0].tolist())
169 | m_n_train_trials = train_size_list
170 | m_n_valid_trials = np.round(np.mean(n_valid_trials, axis=0)).tolist()
171 | s_n_valid_trials = np.round(np.std(n_valid_trials, axis=0)).tolist()
172 | m_n_valid_trials = list(map(int, m_n_valid_trials))
173 | s_n_valid_trials = list(map(int, s_n_valid_trials))
174 | m_n_folds = np.round(np.mean(n_folds, axis=0)).tolist()
175 | s_n_folds = np.round(np.std(n_folds, axis=0)).tolist()
176 | m_n_folds = list(map(int, m_n_folds))
177 | s_n_folds = list(map(int, s_n_folds))
178 | m_acc = np.mean(np.array(results['m_acc']), axis=0)
179 | s_acc = np.mean(np.array(results['s_acc']), axis=0)
180 |
181 | # plotting learning curve for total mean
182 | plot_path = join(savings_dir, 'transfer_learning_curve')
183 | if learning_type == 'dl':
184 | # title = '{} learning curve\n({}$\pm${} samples, {}$\pm${} validation ' \
185 | # 'samples)'.format('average', m_n_trials, s_n_trials,
186 | # m_n_valid_trials[0], s_n_valid_trials[0])
187 | title = 'totale esempi per soggetto: {}$\\pm${}; esempi di validazione: ' \
188 | '{}$\\pm${}'.format(m_n_trials, s_n_trials,
189 | m_n_valid_trials[0], s_n_valid_trials[0])
190 | else:
191 | title = '{} learning curve\n({}$\\pm${} samples)'.format(
192 | 'average', m_n_trials, s_n_trials)
193 | x_tick_labels = ['{}\n({}$\\pm${} fold)'.format(
194 | m_n_train_trials[idx], m_n_folds[idx], s_n_folds[idx])
195 | for idx in range(n_trainings)]
196 | plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
197 | plt.style.use('seaborn-whitegrid')
198 | plt.errorbar(x=[2, 4, 6, 8, 10, 12], y=m_acc, yerr=s_acc,
199 | fmt='-.o', color='b', ecolor='r',
200 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
201 | plt.xlabel('esempi di training', fontsize=fontsize_1)
202 | plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
203 | plt.xticks([2, 4, 6, 8, 10, 12], labels=x_tick_labels, fontsize=fontsize_2)
204 | plt.yticks(fontsize=fontsize_2)
205 | plt.title(title, fontsize=fontsize_1)
206 |
207 | # in case of single_subject with (-500, 4000) epoching, fixing the y limits
208 | # (left over from learning_curve.py; never true for this script's fold_type)
209 | if fold_type == 'single_subject' and epoching == '-500_4000':
210 | plt.ylim(80, 100)
211 |
212 | # saving figure
213 | savefig(plot_path, bbox_inches='tight')
214 |
--------------------------------------------------------------------------------
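The final plot averages at two levels: fold accuracies are first averaged within each subject (m_acc), then results['m_acc'] is averaged across subjects. A toy check of the second, across-subject step:

    import numpy as np

    # toy per-subject means: rows are subjects, columns are train sizes
    per_subject_m_acc = np.array([[90.0, 92.0],
                                  [88.0, 91.0]])
    grand_mean = np.mean(per_subject_m_acc, axis=0)
    print(grand_mean)  # [89.  91.5]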
/sub_routines/transfer_learning_curve_frozen_layers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, dirname
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from pylab import savefig
6 | from csv import reader
7 | from hgdecode.utils import get_path, touch_dir
8 |
9 | """
10 | SET YOUR PARAMETERS HERE
11 | """
12 | ival = (-500, 4000)
13 | frozen_layers_list = [1, 2, 3, 4, 5, -1, -2, -3, -4, -5, 6]
14 |
15 | fontsize_1 = 35
16 | fontsize_2 = 27.5
17 | fig_size = (22, 7.5)
18 |
19 | """
20 | GETTING PATHS
21 | """
22 | folder_paths = [get_path(
23 | results_dir=None,
24 | learning_type='dl',
25 | algorithm_or_model_name=None,
26 | epoching=ival,
27 | fold_type='transfer_learning',
28 | n_folds=128,
29 | deprecated=False
30 | )]
31 |
32 | folder_paths += [
33 | get_path(
34 | results_dir=None,
35 | learning_type='dl',
36 | algorithm_or_model_name='DeepConvNet',
37 | epoching=ival,
38 | fold_type='transfer_learning_frozen',
39 | n_folds=x,
40 | deprecated=False
41 | )
42 | for x in frozen_layers_list
43 | ]
44 |
45 | # getting file_path
46 | csv_paths = [
47 | os.path.join(x, 'statistics', 'tables', 'acc.csv')
48 | for x in folder_paths
49 | ]
50 |
51 | """
52 | SETTING PATHS
53 | """
54 | # saving stuff
55 | savings_dir = join(dirname(
56 | dirname(dirname(folder_paths[1]))), 'learning_curve')
57 | touch_dir(savings_dir)
58 |
59 | """
60 | GETTING DATA FROM CSV FILES
61 | """
62 | subj_data = []
63 | mean_data = []
64 | stdd_data = []
65 | for idx, csv_path in enumerate(csv_paths):
66 | with open(csv_path) as f:
67 | csv = list(reader(f))
68 | csv = csv[1:]
69 | csv = [
70 | list(map(float, csv[x]))
71 | for x in range(len(csv))
72 | ]
73 | temp_data = []
74 | for csv_line in csv:
75 | temp_data.append(csv_line[-2])
76 | subj_data.append(temp_data)
77 | mean_data.append(np.mean(temp_data))
78 | stdd_data.append(np.std(temp_data))
79 |
80 | """
81 | DEMIXING: separating the baseline run (index 0, no frozen layers) from the positive and negative freezing indices; the last entry (6 frozen layers) closes both curves
82 | """
83 | pos_layers_m = np.round(np.array(mean_data[0:6] + [mean_data[-1]]) * 100, 1)
84 | neg_layers_m = np.round(np.array([mean_data[0]] + mean_data[-6:]) * 100, 1)
85 | pos_layers_s = np.round(np.array(stdd_data[0:6] + [stdd_data[-1]]) * 100, 1)
86 | neg_layers_s = np.round(np.array([stdd_data[0]] + stdd_data[-6:]) * 100, 1)
87 |
88 | """
89 | PLOTTING POS
90 | """
91 | plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
92 | plt.style.use('seaborn-whitegrid')
93 | plt.errorbar(x=[0, 1, 2, 3, 4, 5, 6],
94 | y=pos_layers_m,
95 | yerr=pos_layers_s,
96 | fmt='-.o', color='b', ecolor='r',
97 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
98 | plt.xlabel('indice di congelamento',
99 | fontsize=fontsize_1)
100 | plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
101 | plt.xticks(fontsize=fontsize_2)
102 | plt.yticks(fontsize=fontsize_2)
103 | plt.title('esempi di training: 128', fontsize=fontsize_1)
104 | savefig(
105 | join(savings_dir, 'frozen_layers_learning_curve_1'), bbox_inches='tight')
106 |
107 | """
108 | PLOTTING NEG
109 | """
110 | plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
111 | plt.style.use('seaborn-whitegrid')
112 | plt.errorbar(x=[0, -1, -2, -3, -4, -5, -6],
113 | y=neg_layers_m,
114 | yerr=neg_layers_s,
115 | fmt='-.o', color='b', ecolor='r',
116 | linewidth=2, elinewidth=3, capsize=20, capthick=2)
117 | plt.xlabel('indice di congelamento',
118 | fontsize=fontsize_1)
119 | plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
120 | plt.xticks(fontsize=fontsize_2)
121 | plt.yticks(fontsize=fontsize_2)
122 | plt.title('esempi di training: 128', fontsize=fontsize_1)
123 | savefig(
124 | join(savings_dir, 'frozen_layers_learning_curve_2'), bbox_inches='tight')
125 |
--------------------------------------------------------------------------------
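The demixing above relies on the ordering of frozen_layers_list: csv_paths[0] is the non-frozen baseline, the next five entries are the positive freezing indices, then the five negative ones, and the final entry (all 6 layers frozen) closes both curves, since freezing the first six layers and the last six coincide when there are exactly six freezable layers. A toy check of that index bookkeeping:

    # stand-in values 0..11: index 0 is the baseline run (no frozen layers),
    # indices 1..11 follow frozen_layers_list = [1, 2, 3, 4, 5, -1, -2, -3, -4, -5, 6]
    mean_data = list(range(12))
    pos = mean_data[0:6] + [mean_data[-1]]  # baseline, freeze 1..5, freeze 6
    neg = [mean_data[0]] + mean_data[-6:]   # baseline, freeze -1..-5, freeze 6
    assert pos == [0, 1, 2, 3, 4, 5, 11]
    assert neg == [0, 6, 7, 8, 9, 10, 11]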
/transfer_learning.py:
--------------------------------------------------------------------------------
1 | from os import getcwd
2 | from numpy import ceil
3 | from os.path import join
4 | from os.path import dirname
5 | from collections import OrderedDict
6 | from numpy.random import RandomState
7 | from hgdecode.utils import get_path
8 | from hgdecode.utils import create_log
9 | from hgdecode.utils import print_manager
10 | from hgdecode.utils import clear_all_models
11 | from hgdecode.loaders import dl_loader
12 | from hgdecode.classes import CrossValidation
13 | from hgdecode.experiments import DLExperiment
14 | from keras import backend as K
15 |
16 | """
17 | SETTING PARAMETERS
18 | ------------------
19 | """
20 |
21 | # setting model_name and validation_frac
22 | model_name = 'DeepConvNet' # Schirrmeister: 'DeepConvNet' or 'ShallowNet'
23 |
24 | # setting channel_names
25 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
26 | 'CP5', 'CP1', 'CP2', 'CP6',
27 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
28 | 'CP3', 'CPz', 'CP4',
29 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
30 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
31 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
32 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
33 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
34 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']
35 |
36 | # setting data_dir & results_dir
37 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
38 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')
39 |
40 | # setting name_to_start_codes
41 | name_to_start_codes = OrderedDict([('Right Hand', [1]),
42 | ('Left Hand', [2]),
43 | ('Rest', [3]),
44 | ('Feet', [4])])
45 |
46 | # setting subject_ids
47 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
48 |
49 | # setting random_state
50 | random_state = RandomState(1234)
51 |
52 | # setting fold_size: this will be the number of trials for training,
53 | # so it must be a multiple of 4 (keeping the 4 classes balanced)
54 | fold_size = 8 # must be an integer
55 | validation_frac = 0.1
56 |
57 | # setting frozen_layers
58 | layers_to_freeze = 0 # can be between -6 and 6 for DeepConvNet
59 |
60 | # other hyper-parameters
61 | dropout_rate = 0.6
62 | learning_rate = 2 * 1e-5
63 | epochs = 100
64 | ival = (-1000, 1000)
65 |
66 | """
67 | GETTING CROSS-SUBJECT MODELS DIR PATH
68 | -------------------------------------
69 | """
70 | # setting cross_subj_dir_path: data from cross-subj computation are stored here
71 | learning_type = 'dl'
72 | algorithm_or_model_name = None
73 | epoching = ival
74 | fold_type = 'cross_subject'
75 | n_folds = None
76 | deprecated = False
77 | cross_subj_dir_path = get_path(
78 | results_dir=dirname(results_dir),
79 | learning_type=learning_type,
80 | algorithm_or_model_name=algorithm_or_model_name,
81 | epoching=epoching,
82 | fold_type=fold_type,
83 | n_folds=n_folds,
84 | deprecated=deprecated
85 | )
86 |
87 | """
88 | COMPUTATION
89 | -----------
90 | """
91 | for subject_id in subject_ids:
92 | # creating a log object
93 | subj_results_dir = create_log(
94 | results_dir=results_dir,
95 | learning_type='dl',
96 | algorithm_or_model_name=model_name,
97 | subject_id=subject_id,
98 | output_on_file=False,
99 | use_last_result_directory=False
100 | )
101 |
102 | # loading epoched signal
103 | epo = dl_loader(
104 | data_dir=data_dir,
105 | name_to_start_codes=name_to_start_codes,
106 | channel_names=channel_names,
107 | subject_id=subject_id,
108 | resampling_freq=250, # Schirrmeister: 250
109 | clean_ival_ms=ival, # Schirrmeister: (0, 4000)
110 | epoch_ival_ms=ival, # Schirrmeister: (-500, 4000)
111 | train_test_split=True, # Schirrmeister: True
112 | clean_on_all_channels=False # Schirrmeister: True
113 | )
114 |
115 | # if fold_size is not a multiple of 4, rounding it up to the nearest one
116 | fold_size = int(ceil(fold_size / 4) * 4)
117 |
118 | # computing batch_size: equal to fold_size, capped at 64
119 | if fold_size <= 64:
120 | batch_size = fold_size
121 | else:
122 | batch_size = 64
123 |
124 | # I don't think this is a good idea:
125 | # validation_size is equal to fold_size
126 | # validation_size = fold_size
127 |
128 | # creating CrossValidation class instance
129 | cross_validation = CrossValidation(
130 | X=epo.X,
131 | y=epo.y,
132 | fold_size=fold_size,
133 | # validation_size=validation_size,
134 | validation_frac=validation_frac,
135 | random_state=random_state, shuffle=True,
136 | swap_train_test=True,
137 | )
138 | cross_validation.balance_train_set(train_size=fold_size)
139 |
140 | # pre-allocating experiment
141 | exp = None
142 |
143 | # cycling on folds for cross validation
144 | for fold_idx, current_fold in enumerate(cross_validation.folds):
145 | # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
146 | print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
147 | K.clear_session()
148 | print_manager(print_style='last', bottom_return=1)
149 |
150 | # printing fold information
151 | print_manager(
152 | 'SUBJECT {}, FOLD {}'.format(subject_id, fold_idx + 1),
153 | print_style='double-dashed'
154 | )
155 | cross_validation.print_fold_classes(fold_idx)
156 | print_manager(print_style='last', bottom_return=1)
157 |
158 | # creating EEGDataset for current fold
159 | dataset = cross_validation.create_dataset(fold=current_fold)
160 |
161 | # creating experiment instance
162 | exp = DLExperiment(
163 | # non-default inputs
164 | dataset=dataset,
165 | model_name=model_name,
166 | results_dir=results_dir,
167 | subj_results_dir=subj_results_dir,
168 | name_to_start_codes=name_to_start_codes,
169 | random_state=random_state,
170 | fold_idx=fold_idx,
171 |
172 | # hyperparameters
173 | dropout_rate=dropout_rate, # Schirrmeister: 0.5
174 | learning_rate=learning_rate, # Schirrmeister: ?
175 | batch_size=batch_size, # Schirrmeister: 512
176 | epochs=epochs, # Schirrmeister: ?
177 | early_stopping=False, # Schirrmeister: ?
178 | monitor='val_acc', # Schirrmeister: ?
179 | min_delta=0.0001, # Schirrmeister: ?
180 | patience=5, # Schirrmeister: ?
181 | loss='categorical_crossentropy', # Schirrmeister: ad hoc
182 | optimizer='Adam', # Schirrmeister: Adam
183 | shuffle=True, # Schirrmeister: ?
184 | crop_sample_size=None, # Schirrmeister: 1125
185 | crop_step=None, # Schirrmeister: 1
186 |
187 | # other parameters
188 | subject_id=subject_id,
189 | data_generator=False, # Schirrmeister: True
190 | save_model_at_each_epoch=False
191 | )
192 |
193 | # loading model weights from cross-subject pre-trained model
194 | exp.prepare_for_transfer_learning(
195 | cross_subj_dir_path=cross_subj_dir_path,
196 | subject_id=subject_id,
197 | train_anyway=True
198 | )
199 |
200 | # freezing layers
201 | exp.freeze_layers(layers_to_freeze=layers_to_freeze)
202 |
203 | # training
204 | exp.train()
205 |
206 | # computing cross-validation
207 | if exp is not None:
208 | cross_validation.cross_validate(
209 | subj_results_dir=exp.subj_results_dir,
210 | label_names=name_to_start_codes)
211 |
212 | # clearing all models (they are not useful once we have results)
213 | clear_all_models(subj_results_dir)
214 |
--------------------------------------------------------------------------------
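Two small rules govern the training setup above: fold_size is rounded up to a multiple of 4 so the four classes stay balanced, and batch_size equals fold_size up to a cap of 64. A quick sketch of both:

    from numpy import ceil

    # a few requested fold sizes, including non-multiples of 4
    for requested in (6, 8, 100):
        fold_size = int(ceil(requested / 4) * 4)  # round up to a multiple of 4
        batch_size = fold_size if fold_size <= 64 else 64  # cap at 64
        print(requested, '->', fold_size, batch_size)
    # prints: 6 -> 8 8, then 8 -> 8 8, then 100 -> 100 64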