├── .DS_Store ├── .gitignore ├── .idea ├── codeStyles │ └── codeStyleConfig.xml ├── dictionaries │ ├── dave.xml │ └── davidemiani.xml ├── encodings.xml ├── hgdecode.iml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── other.xml ├── vcs.xml └── workspace.xml ├── README.md ├── __init__.py ├── dl_main.py ├── dl_main_cross_subject.py ├── hgdecode ├── classes.py ├── experiments.py ├── fbcsprlda.py ├── lda.py ├── loaders.py ├── models.py ├── signalproc.py └── utils.py ├── ml_main.py ├── ml_main_cross_subject.py ├── schirrmeister_main.py ├── sub_routines ├── dl_cross_validation.py ├── latex_tabular_parser.py ├── latex_tabular_parser_cross_subj.py ├── latex_tabular_parser_transfer_learning.py ├── latex_tabular_parser_transfer_learning_frozen_layers.py ├── learning_curve.py ├── ml_cross_validation.py ├── t_test.py ├── transfer_learning_curve.py └── transfer_learning_curve_frozen_layers.py └── transfer_learning.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidemiani/hgdecode/359715c35035b7a37c0767e221a0ffa73254c7be/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | 4 | \.idea/codeStyles/codeStyleConfig\.xml 5 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/dictionaries/dave.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | accs 5 | arxiv 6 | asctime 7 | bandpassed 8 | bandpassing 9 | bbci 10 | bestfilterband 11 | bincsp 12 | butterworth 13 | conv 14 | convolutional 15 | crossentropy 16 | datasets 17 | 
depthwise 18 | dmdl 19 | eigenvectors 20 | epoched 21 | epoching 22 | fbcsp 23 | filt 24 | filterband 25 | filterbands 26 | filterbank 27 | filterpair 28 | filterwise 29 | hgdecode 30 | hyperparameter 31 | hyperparameters 32 | inds 33 | ival 34 | keras 35 | levelname 36 | miani 37 | microvolt 38 | multiclass 39 | ndarray 40 | npint 41 | overcomplete 42 | pointwise 43 | preds 44 | regularizations 45 | regularizer 46 | resample 47 | resampling 48 | robintibor 49 | schirrmeister 50 | scipy 51 | sfreq 52 | softmax 53 | ssvep 54 | waytowich 55 | 56 | 57 | -------------------------------------------------------------------------------- /.idea/dictionaries/davidemiani.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | algo 5 | barstr 6 | dataset 7 | davidemiani 8 | filtfilt 9 | frac 10 | hline 11 | idxs 12 | ipykernel 13 | isatty 14 | magistrale 15 | numdigits 16 | nyquist 17 | occurrs 18 | perc 19 | prec 20 | pred 21 | preprocess 22 | prog 23 | pval 24 | rowcolor 25 | seaborn 26 | stdd 27 | studiorum 28 | subjs 29 | tesi 30 | textbf 31 | totale 32 | universita 33 | whitegrid 34 | xticks 35 | 36 | 37 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/hgdecode.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 14 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 24 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 10 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hgdecode 2 | Implements High-Gamma dataset decoding using Filter Bank Common Spatial Pattern with rLDA classification and Neural Networks. 
# -----------------------------------------------------------------------------
# /__init__.py:
# -----------------------------------------------------------------------------
# https://raw.githubusercontent.com/davidemiani/hgdecode/359715c35035b7a37c0767e221a0ffa73254c7be/__init__.py
# -----------------------------------------------------------------------------
# /dl_main.py:
# -----------------------------------------------------------------------------
from os import getcwd
from os.path import join
from os.path import dirname
from collections import OrderedDict
from numpy.random import RandomState
from hgdecode.utils import create_log
from hgdecode.utils import print_manager
from hgdecode.loaders import dl_loader
from hgdecode.classes import CrossValidation
from hgdecode.experiments import DLExperiment
from keras import backend as K

"""
SETTING PARAMETERS
------------------
In the following, set / modify all the parameters used by the computation
below.

Parameters
----------
model_name : str
    Name of the Deep Learning model, i.e. a constructor in hgdecode.models
    (e.g. 'DeepConvNet' or 'ShallowNet')
channel_names : list
    Channels to use for computation
data_dir : str
    Path to the directory that contains the dataset
results_dir : str
    Path to the directory that will contain the results
name_to_start_codes : OrderedDict
    All possible class names and their marker codes, in order
random_state : numpy.random.RandomState
    Random generator shared by all the random calls (seeded with 1234)
standardize_mode : int
    Standardization mode forwarded to dl_loader
subject_ids : tuple
    All the subject ids in a tuple; add or remove subjects to run the
    algorithm for them or not
ival : tuple of 2 ints
    Interval in milliseconds, used both for trial cleaning and for epoching
n_folds : int or None
    Number of cross-validation folds; when None, the train set is balanced
    to ``fold_size`` trials via CrossValidation.balance_train_set
fold_size : int or None
    Forwarded to CrossValidation (and to balance_train_set when n_folds is
    None)
swap_train_test : bool
    Forwarded to CrossValidation
learning_rate, dropout_rate, batch_size, epochs
    Training hyperparameters forwarded to DLExperiment
"""
# setting model_name (validation_frac is passed directly to CrossValidation)
model_name = 'DeepConvNet'  # Schirrmeister: 'DeepConvNet' or 'ShallowNet'

# setting channel_names
channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
                 'CP5', 'CP1', 'CP2', 'CP6',
                 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
                 'CP3', 'CPz', 'CP4',
                 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
                 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
                 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
                 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
                 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']

# setting data_dir & results_dir (relative to the current working directory)
data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')

# setting name_to_start_codes
name_to_start_codes = OrderedDict([('Right Hand', [1]),
                                   ('Left Hand', [2]),
                                   ('Rest', [3]),
                                   ('Feet', [4])])

# setting random_state
random_state = RandomState(1234)

# hyperparameters actually used by this run
standardize_mode = 2
subject_ids = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
ival = (-500, 4000)
n_folds = 12
fold_size = None
swap_train_test = False
learning_rate = 1e-4
dropout_rate = 0.5
batch_size = 64
epochs = 1000

"""
MAIN CYCLE
----------
For each subject, a new log is created and the subject's dataset is loaded;
this dataset is split into cross-validation folds, and for each fold an
experiment instance is created and trained. Finally, the cross-validation
results of the subject are computed. You can of course change all the
experiment inputs to obtain different results.
"""
for subject_id in subject_ids:
    # creating a log object
    subj_results_dir = create_log(
        results_dir=results_dir,
        learning_type='dl',
        algorithm_or_model_name=model_name,
        subject_id=subject_id,
        output_on_file=False
    )

    # loading epoched signal
    epo = dl_loader(
        data_dir=data_dir,
        name_to_start_codes=name_to_start_codes,
        channel_names=channel_names,
        subject_id=subject_id,
        resampling_freq=250,  # Schirrmeister: 250
        clean_ival_ms=ival,  # Schirrmeister: (0, 4000)
        epoch_ival_ms=ival,  # Schirrmeister: (-500, 4000)
        train_test_split=True,  # Schirrmeister: True
        clean_on_all_channels=False,  # Schirrmeister: True
        standardize_mode=standardize_mode  # Schirrmeister: 2
    )

    # creating CrossValidation class instance
    cv = CrossValidation(
        X=epo.X,
        y=epo.y,
        n_folds=n_folds,
        fold_size=fold_size,
        validation_frac=0.1,
        random_state=random_state,
        shuffle=True,
        swap_train_test=swap_train_test
    )
    if n_folds is None:
        cv.balance_train_set(train_size=fold_size)

    # pre-allocating experiment (stays None if there are no folds)
    exp = None

    # cycling on folds for cross validation
    for fold_idx, current_fold in enumerate(cv.folds):
        # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
        print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
        K.clear_session()
        print_manager(print_style='last', bottom_return=1)

        # printing fold information
        print_manager(
            'SUBJECT {}, FOLD {}'.format(subject_id, fold_idx + 1),
            print_style='double-dashed'
        )
        cv.print_fold_classes(fold_idx)
        print_manager(print_style='last', bottom_return=1)

        # creating EEGDataset for current fold
        dataset = cv.create_dataset(fold=current_fold)

        # creating experiment instance
        exp = DLExperiment(
            # non-default inputs
            dataset=dataset,
            model_name=model_name,
            results_dir=results_dir,
            subj_results_dir=subj_results_dir,
            name_to_start_codes=name_to_start_codes,
            random_state=random_state,
            fold_idx=fold_idx,

            # hyperparameters
            dropout_rate=dropout_rate,  # Schirrmeister: 0.5
            learning_rate=learning_rate,  # Schirrmeister: ?
            batch_size=batch_size,  # Schirrmeister: 512
            epochs=epochs,  # Schirrmeister: ?
            early_stopping=False,  # Schirrmeister: ?
            monitor='val_acc',  # Schirrmeister: ?
            min_delta=0.0001,  # Schirrmeister: ?
            patience=5,  # Schirrmeister: ?
            loss='categorical_crossentropy',  # Schirrmeister: ad hoc
            optimizer='Adam',  # Schirrmeister: Adam
            shuffle=True,  # Schirrmeister: ?
            crop_sample_size=None,  # Schirrmeister: 1125
            crop_step=None,  # Schirrmeister: 1

            # other parameters
            subject_id=subject_id,
            data_generator=False,  # Schirrmeister: True
            save_model_at_each_epoch=False
        )

        # training
        exp.train()

    if exp is not None:
        # computing cross-validation over all the folds of this subject
        cv.cross_validate(subj_results_dir=subj_results_dir,
                          label_names=name_to_start_codes)
# -----------------------------------------------------------------------------
# /dl_main_cross_subject.py:
# -----------------------------------------------------------------------------
from os import getcwd
from os.path import join
from os.path import dirname
from collections import OrderedDict
from numpy.random import RandomState
from hgdecode.utils import create_log
from hgdecode.utils import print_manager
from hgdecode.loaders import CrossSubject
from hgdecode.classes import CrossValidation
from hgdecode.experiments import DLExperiment
from keras import backend as K

"""
SETTING PARAMETERS
Here you can set all the parameters used below; subject_ids lists the
subjects taking part in the leave-one-subject-out cross-validation.
"""
# setting model_name
model_name = 'DeepConvNet'

# setting channel_names
channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
                 'CP5', 'CP1', 'CP2', 'CP6',
                 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
                 'CP3', 'CPz', 'CP4',
                 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
                 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
                 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
                 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
                 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']

# setting data_dir & results_dir (relative to the current working directory)
data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')

# setting name_to_start_codes
name_to_start_codes = OrderedDict([('Right Hand', [1]),
                                   ('Left Hand', [2]),
                                   ('Rest', [3]),
                                   ('Feet', [4])])

# setting random_state
random_state = RandomState(1234)

# setting subject_ids
subject_ids = (1, 2)  # , 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)

# setting hyperparameters
ival = (-500, 4000)
standardize_mode = 2
learning_rate = 1 * 1e-4
dropout_rate = 0.5
batch_size = 32
epochs = 800

"""
STARTING LOADING ROUTINE & COMPUTATION
Here you can also change some of the parameters passed to the function calls
"""
# creating a log object
subj_results_dir = create_log(
    results_dir=results_dir,
    learning_type='dl',
    algorithm_or_model_name=model_name,
    subject_id='subj_cross',
    output_on_file=False
)

# creating a cross-subject object for cross-subject validation
cross_obj = CrossSubject(data_dir=data_dir,
                         subject_ids=subject_ids,
                         channel_names=channel_names,
                         name_to_start_codes=name_to_start_codes,
                         random_state=random_state,
                         validation_frac=0.1,
                         resampling_freq=250,
                         train_test_split=True,
                         clean_ival_ms=ival,
                         epoch_ival_ms=ival,
                         clean_on_all_channels=False)

# parsing all cnt data to epoched (cnt data is no longer needed afterwards)
cross_obj.parser(output_format='epo', parsing_type=1)

# pre-allocating experiment (stays None if subject_ids is empty)
exp = None

# cycling on the subject left out of training (leave-one-subject-out)
for leave_subj in subject_ids:
    # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
    print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
    K.clear_session()
    print_manager(print_style='last', bottom_return=1)

    # creating dataset for this "all but" fold
    cross_obj.parser(output_format='EEGDataset',
                     leave_subj=leave_subj,
                     parsing_type=1)

    # creating experiment instance
    exp = DLExperiment(
        # non-default inputs
        dataset=cross_obj.fold_data,
        model_name=model_name,
        results_dir=results_dir,
        name_to_start_codes=name_to_start_codes,
        random_state=random_state,
        # fold numbering follows the subject ids (1-based id -> 0-based fold)
        fold_idx=leave_subj - 1,

        # hyperparameters
        dropout_rate=dropout_rate,  # Schirrmeister: 0.5
        learning_rate=learning_rate,  # Schirrmeister: ?
        batch_size=batch_size,  # Schirrmeister: 512
        epochs=epochs,  # Schirrmeister: ?
        early_stopping=False,  # Schirrmeister: ?
        monitor='val_acc',  # Schirrmeister: ?
        min_delta=0.0001,  # Schirrmeister: ?
        patience=5,  # Schirrmeister: ?
        loss='categorical_crossentropy',  # Schirrmeister: ad hoc
        optimizer='Adam',  # Schirrmeister: Adam
        shuffle=True,  # Schirrmeister: ?
        crop_sample_size=None,  # Schirrmeister: 1125
        crop_step=None,  # Schirrmeister: 1

        # other parameters
        subject_id='_cross',
        data_generator=False,  # Schirrmeister: True
        save_model_at_each_epoch=False,
        subj_results_dir=subj_results_dir
    )

    # running training
    exp.train()

# at the very end, running cross-validation
if exp is not None:
    CrossValidation.cross_validate(subj_results_dir=exp.subj_results_dir,
                                   label_names=name_to_start_codes)
# -----------------------------------------------------------------------------
# /hgdecode/experiments.py:
# -----------------------------------------------------------------------------
# General modules
import numpy as np
from numpy import arange
from numpy import setdiff1d
# `numpy.int` was only an alias of the builtin `int`: it was deprecated in
# NumPy 1.20 and removed in 1.24, so the builtin is imported under that name
from builtins import int as npint
from pickle import load
from os.path import join
from os.path import dirname
from itertools import combinations
from hgdecode.utils import touch_dir
from hgdecode.utils import my_formatter
from hgdecode.utils import print_manager
from sklearn.metrics import confusion_matrix
from multiprocessing import cpu_count

# Deep Learning
from hgdecode import models
from keras import optimizers
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from hgdecode.classes import MetricsTracker
from hgdecode.classes import EEGDataGenerator

# Machine Learning
from hgdecode.classes import FilterBank
from hgdecode.fbcsprlda import BinaryFBCSP
from hgdecode.fbcsprlda import FBCSP
from hgdecode.fbcsprlda import MultiClassWeightedVoting
from braindecode.datautil.iterators import get_balanced_batches


class FBCSPrLDAExperiment(object):
    """
    A Filter Bank Common Spatial Patterns with rLDA
    classification Experiment.

    ``run`` builds the filter bank and the folds, then runs, in order, a
    binary FBCSP rLDA for every pair of classes (all 2-combinations of the
    classes), the FBCSP filterband / feature selection and a weighted
    multiclass voting.

    Parameters
    ----------
    cnt : RawArray
        The continuous train recordings with events in info['events']
    clean_trial_mask : bool array
        Bool array containing information about valid/invalid trials
    name_to_start_codes: dict
        Dictionary mapping class names to marker numbers, e.g.
        {'1 - Correct': [31], '2 - Error': [32]}
    random_state : numpy.random.RandomState
        Random generator passed to get_balanced_batches when building
        pseudo-random folds.
    name_to_stop_codes : dict or None
        Forwarded unchanged to BinaryFBCSP.
    epoch_ival_ms : sequence of 2 floats
        The start and end of the trial in milliseconds with respect to
        the markers.
    cross_subject_object : object or None
        If given, folds are built leaving one subject out at a time, using
        its ``subject_indexes`` attribute (one (start, stop) pair of trial
        indexes per subject). Takes precedence over ``fold_file`` and
        ``n_folds``.
    min_freq : int or list or tuple
        The minimum frequency/ies of the filterbank/s.
    max_freq : int or list or tuple
        The maximum frequency/ies of the filterbank/s.
    window : int or list or tuple
        Bandwidths of filters in filterbank/s.
    overlap : int or list or tuple
        Overlap frequencies between filters in filterbank/s.
    filt_order : int
        The filter order of the butterworth filter which computes the
        filterbands.
    n_folds : int
        How many folds. Also determines size of the test fold, e.g.
        5 folds imply the test fold has 20% of the original data.
        0 selects a single train/test split whose test set is made of the
        last 160 trials (Schirrmeister's split).
    fold_file : str or None
        Path to a NumPy ``.npz`` file whose ``'folds'`` entry holds
        precomputed folds. Takes precedence over ``n_folds``.
    n_top_bottom_csp_filters : int or None
        Number of top and bottom CSP filters to select from all computed
        filters. Top and bottom refers to CSP filters sorted by their
        eigenvalues. So a value of 3 here will lead to 6(!) filters.
        None means all filters.
    n_selected_filterbands : int or None
        Number of filterbands to select for the filterbank.
        Will be selected by the highest training accuracies.
        None means all filterbands.
    n_selected_features : int or None
        Number of features to select for the filterbank.
        Will be selected by an internal cross validation across feature
        subsets.
        None means all features.
    forward_steps : int
        Number of forward steps to make in the feature selection,
        before the next backward step.
    backward_steps : int
        Number of backward steps to make in the feature selection,
        before the next forward step.
    stop_when_no_improvement: bool
        Whether to stop the feature selection if the internal cross
        validation accuracy could not be improved after an epoch finished
        (epoch=given number of forward and backward steps).
        False implies always run until wanted number of features.
    shuffle: bool
        Whether to shuffle the clean trials before splitting them into
        folds. False implies folds are time-blocks, True implies folds are
        random mixes of trials of the entire file.
    average_trial_covariance : bool
        Forwarded unchanged to BinaryFBCSP.
    """

    def __init__(self,
                 # signal-related inputs
                 cnt,
                 clean_trial_mask,
                 name_to_start_codes,
                 random_state,
                 name_to_stop_codes=None,
                 epoch_ival_ms=(-500, 4000),
                 cross_subject_object=None,

                 # bank filter-related inputs
                 min_freq=0,
                 max_freq=12,
                 window=6,
                 overlap=3,
                 filt_order=3,

                 # machine learning-related inputs
                 n_folds=5,
                 fold_file=None,
                 n_top_bottom_csp_filters=None,
                 n_selected_filterbands=None,
                 n_selected_features=None,
                 forward_steps=2,
                 backward_steps=1,
                 stop_when_no_improvement=False,
                 shuffle=False,
                 average_trial_covariance=True):
        # signal-related inputs
        self.cnt = cnt
        self.clean_trial_mask = clean_trial_mask
        self.epoch_ival_ms = epoch_ival_ms
        self.name_to_start_codes = name_to_start_codes
        self.name_to_stop_codes = name_to_stop_codes
        self.random_state = random_state
        if cross_subject_object is None:
            self.cross_subject_object = None
            self.cross_subject_computation = False
        else:
            self.cross_subject_object = cross_subject_object
            self.cross_subject_computation = True

        # bank filter-related inputs
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.window = window
        self.overlap = overlap
        self.filt_order = filt_order

        # machine learning-related inputs
        self.n_folds = n_folds
        self.n_top_bottom_csp_filters = n_top_bottom_csp_filters
        self.n_selected_filterbands = n_selected_filterbands
        self.n_selected_features = n_selected_features
        self.forward_steps = forward_steps
        self.backward_steps = backward_steps
        self.stop_when_no_improvement = stop_when_no_improvement
        self.shuffle = shuffle
        self.average_trial_covariance = average_trial_covariance
        if fold_file is None:
            self.fold_file = None
            self.load_fold_from_file = False
        else:
            self.fold_file = fold_file
            self.load_fold_from_file = True

        # other fundamental properties (they will be filled in further
        # computational steps)
        self.filterbank_csp = None
        self.folds = None
        self.binary_csp = None
        self.filterbands = None
        self.multi_class = None

        # computing other properties for further computation
        self.n_classes = len(self.name_to_start_codes)
        # every unordered pair of class indexes, one per binary problem
        self.class_pairs = list(combinations(range(self.n_classes), 2))
        # number of valid trials; builtin int replaces the removed np.int
        self.n_trials = self.clean_trial_mask.astype(int).sum()

    def create_filter_bank(self):
        """Build the FilterBank from the frequency settings of the object."""
        self.filterbands = FilterBank(
            min_freq=self.min_freq,
            max_freq=self.max_freq,
            window=self.window,
            overlap=self.overlap
        )

    def create_folds(self):
        """
        Fill ``self.folds`` with a list of {'train': idxs, 'test': idxs}.

        The first applicable strategy wins: cross-subject (one fold per
        subject), folds loaded from ``fold_file``, Schirrmeister's single
        split (``n_folds == 0``), or ``n_folds`` balanced random folds.
        """
        if self.cross_subject_computation is True:
            # in case of cross-subject computation: each subject's trial
            # range is the test set of one fold
            folds = [
                arange(
                    self.cross_subject_object.subject_indexes[x][0],
                    self.cross_subject_object.subject_indexes[x][1]
                )
                for x in range(len(self.cross_subject_object.subject_indexes))
            ]
            self.n_folds = len(folds)
            self.folds = [
                {
                    'train': setdiff1d(arange(self.n_trials), fold),
                    'test': fold
                }
                for fold in folds
            ]
        elif self.load_fold_from_file is True:
            # in case of pre-batched computation; folds are stored as an
            # object array (list of dicts), which NumPy >= 1.16.3 refuses to
            # load unless allow_pickle=True. Only load trusted fold files:
            # unpickling can execute arbitrary code.
            with np.load(self.fold_file, allow_pickle=True) as fold_data:
                self.folds = fold_data['folds']
        elif self.n_folds == 0:
            self.n_folds = 1

            # creating schirrmeister fold: the last n_test_trials trials
            # (counted among all trials, clean or not) form the test set
            n_test_trials = 160
            all_idxs = np.arange(len(self.clean_trial_mask))
            self.folds = [
                {
                    'train': all_idxs[:-n_test_trials],
                    'test': all_idxs[-n_test_trials:]
                }
            ]
            # keeping only the clean trials in each set
            # NOTE(review): these indexes live in the space of *all* trials,
            # while the other branches index clean trials only
            # (0..n_trials-1); confirm which space BinaryFBCSP expects
            self.folds[0]['train'] = self.folds[0]['train'][
                self.clean_trial_mask[:-n_test_trials]]
            self.folds[0]['test'] = self.folds[0]['test'][
                self.clean_trial_mask[-n_test_trials:]]
        else:
            # getting pseudo-random folds
            folds = get_balanced_batches(
                n_trials=self.n_trials,
                rng=self.random_state,
                shuffle=self.shuffle,
                n_batches=self.n_folds
            )
            self.folds = [
                {
                    'train': setdiff1d(arange(self.n_trials), fold),
                    'test': fold
                }
                for fold in folds
            ]

    def run(self):
        """Run filter bank/fold creation, binary FBCSP, FBCSP, multiclass."""
        # printing routine start
        print_manager(
            'INIT TRAINING ROUTINE',
            'double-dashed',
        )

        # creating filter bank
        print_manager('Creating filter bank...')
        self.create_filter_bank()
        print_manager('DONE!!', bottom_return=1)

        # creating folds
        print_manager('Creating folds...')
        self.create_folds()
        print_manager('DONE!!', 'last')

        # running binary FBCSP
        print_manager("RUNNING BINARY FBCSP rLDA",
                      'double-dashed',
                      top_return=1)
        self.binary_csp = BinaryFBCSP(
            cnt=self.cnt,
            clean_trial_mask=self.clean_trial_mask,
            filterbands=self.filterbands,
            filt_order=self.filt_order,
            folds=self.folds,
            class_pairs=self.class_pairs,
            epoch_ival_ms=self.epoch_ival_ms,
            n_filters=self.n_top_bottom_csp_filters,
            marker_def=self.name_to_start_codes,
            name_to_stop_codes=self.name_to_stop_codes,
            average_trial_covariance=self.average_trial_covariance
        )
        self.binary_csp.run()

        # at the very end of the binary CSP experiment, running the real one
        print_manager("RUNNING FBCSP rLDA", 'double-dashed', top_return=1)
        self.filterbank_csp = FBCSP(
            binary_csp=self.binary_csp,
            n_features=self.n_selected_features,
            n_filterbands=self.n_selected_filterbands,
            forward_steps=self.forward_steps,
            backward_steps=self.backward_steps,
            stop_when_no_improvement=self.stop_when_no_improvement
        )
        self.filterbank_csp.run()

        # and finally multiclass
        print_manager("RUNNING MULTICLASS", 'double-dashed', top_return=1)
        self.multi_class = MultiClassWeightedVoting(
            train_labels=self.binary_csp.train_labels_full_fold,
            test_labels=self.binary_csp.test_labels_full_fold,
            train_preds=self.filterbank_csp.train_pred_full_fold,
            test_preds=self.filterbank_csp.test_pred_full_fold,
            class_pairs=self.class_pairs)
        self.multi_class.run()
        print('\n')


class DLExperiment(object):
    """
    A Deep Learning experiment on one fold of a dataset.

    On construction it creates the results directory tree of the fold,
    builds the network by calling ``models.<model_name>(n_classes,
    n_channels, crop_sample_size, dropout_rate)`` and compiles it with an
    Adam optimizer, ``loss`` and an accuracy metric. ``train`` writes a
    model report, sets up the Keras callbacks and trains the network.

    Parameters
    ----------
    dataset : object
        Fold dataset; this class reads ``X_train``, ``y_train``,
        ``n_channels``, ``n_samples``, ``shape``, ``train_frac``,
        ``valid_frac``, ``test_frac`` and ``len()`` from it.
    model_name : str
        Name of a model constructor in ``hgdecode.models``.
    results_dir, subj_results_dir : str
        Root results directory and the subject's results directory.
    name_to_start_codes : dict
        Class names to marker codes; its length is the number of classes.
    random_state : numpy.random.RandomState
        Stored on the object.
    fold_idx : int
        0-based fold index, used to name the fold directory.
    crop_sample_size, crop_step : int or None
        When crop_sample_size is None, the whole trial (n_samples) is used
        and crop_step is set to 1.
    """

    def __init__(self,
                 # non-default inputs
                 dataset,
                 model_name,
                 results_dir,
                 subj_results_dir,
                 name_to_start_codes,
                 random_state,
                 fold_idx,

                 # hyperparameters
                 dropout_rate=0.5,
                 learning_rate=0.001,
                 batch_size=128,
                 epochs=10,
                 early_stopping=False,
                 monitor='val_acc',
                 min_delta=0.0001,
                 patience=5,
                 loss='categorical_crossentropy',
                 optimizer='Adam',
                 # NOTE(review): this default is the *string* 'False', which
                 # is truthy; confirm whether a boolean False was intended
                 shuffle='False',
                 crop_sample_size=None,
                 crop_step=None,

                 # other parameters
                 subject_id=1,
                 data_generator=False,
                 workers=cpu_count(),
                 save_model_at_each_epoch=False):
        # non-default inputs
        self.dataset = dataset
        self.model_name = model_name
        self.results_dir = results_dir
        self.subj_results_dir = subj_results_dir
        self.datetime_results_dir = dirname(subj_results_dir)
        self.name_to_start_codes = name_to_start_codes
        self.random_state = random_state
        self.fold_idx = fold_idx

        # hyperparameters
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs
        self.early_stopping = early_stopping
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.loss = loss
        self.optimizer = optimizer
        self.shuffle = shuffle
        if crop_sample_size is None:
            # no cropping: the whole trial is a single "crop"
            self.crop_sample_size = self.n_samples
            self.crop_step = 1
        else:
            self.crop_sample_size = crop_sample_size
            self.crop_step = crop_step

        # other parameters
        self.subject_id = subject_id
        self.data_generator = data_generator
        self.workers = workers
        self.save_model_at_each_epoch = save_model_at_each_epoch
        self.metrics_tracker = None

        # managing paths
        self.dl_results_dir = None
        self.model_results_dir = None
        self.fold_results_dir = None
        self.statistics_dir = None
        self.figures_dir = None
        self.tables_dir = None
        self.model_picture_path = None
        self.model_report_path = None
        self.train_report_path = None
        self.h5_models_dir = None
        self.h5_model_path = None
        self.log_path = None
        self.fold_stats_path = None
        self.paths_manager()

        # importing model: looking the constructor up by name instead of
        # eval()-ing a source string built from the inputs
        print_manager('IMPORTING & COMPILING MODEL', 'double-dashed')
        model_builder = getattr(models, self.model_name)
        self.model = model_builder(self.n_classes,
                                   self.n_channels,
                                   self.crop_sample_size,
                                   self.dropout_rate)

        # creating optimizer instance (only Adam is implemented: any other
        # `optimizer` value falls back to Adam as well)
        opt = optimizers.Adam(lr=self.learning_rate)

        # compiling model
        self.model.compile(loss=self.loss,
                           optimizer=opt,
                           metrics=['accuracy'])
        self.model.summary()
        print_manager('DONE!!', print_style='last', bottom_return=1)

    # NOTE(review): the format strings of __repr__/__str__ are empty as they
    # appear here (possibly stripped on export), so both return ''
    def __repr__(self):
        return ''.format(self.model_name)

    def __str__(self):
        return ''.format(self.model_name)

    def __len__(self):
        return len(self.dataset)

    @property
    def shape(self):
        return self.dataset.shape

    @property
    def train_frac(self):
        return self.dataset.train_frac

    @property
    def valid_frac(self):
        return self.dataset.valid_frac

    @property
    def test_frac(self):
        return self.dataset.test_frac

    @property
    def n_classes(self):
        return len(self.name_to_start_codes)

    @property
    def n_channels(self):
        return self.dataset.n_channels

    @property
    def n_samples(self):
        return self.dataset.n_samples

    def paths_manager(self):
        """Compute all the result paths and create the needed directories."""
        # results_dir is: .../results/hgdecode
        # dl_results_dir is: .../results/hgdecode/dl
        dl_results_dir = join(self.results_dir, 'dl')

        # model_results_dir is: .../results/hgdecode/dl/model_name
        model_results_dir = join(dl_results_dir, self.model_name)

        # fold_results_dir is .../results/dataset/dl/model/datetime/subj/fold
        fold_str = str(self.fold_idx + 1)
        if len(fold_str) == 1:
            fold_str = '0' + fold_str
        fold_str = 'fold' + fold_str
        fold_results_dir = join(self.subj_results_dir, fold_str)

        # setting on object self
        self.dl_results_dir = dl_results_dir
        self.model_results_dir = model_results_dir
        self.fold_results_dir = fold_results_dir

        # touching only the last directory also creates the parent ones
        touch_dir(self.fold_results_dir)

        # statistics_dir is: .../results/hgdecode/dl/model/datetime/statistics
        statistics_dir = join(self.datetime_results_dir, 'statistics')

        # figures_dir is: .../results/hgdecode/dl/model/dt/stat/figures/subject
        figures_dir = join(statistics_dir, 'figures',
                           my_formatter(self.subject_id, 'subj'))

        # tables_dir is: .../results/hgdecode/dl/model/datetime/stat/tables
        tables_dir = join(statistics_dir, 'tables')

        # setting on object self
        self.statistics_dir = statistics_dir
        self.figures_dir = figures_dir
        self.tables_dir = tables_dir
        touch_dir(figures_dir)
        touch_dir(tables_dir)

        # files in datetime_results_dir
        self.model_report_path = join(self.datetime_results_dir,
                                      'model_report.txt')
        self.model_picture_path = join(self.datetime_results_dir,
                                       'model_picture.png')

        # files in subj_results_dir
        self.log_path = join(self.subj_results_dir, 'log.bin')

        # files in fold_results_dir
        self.train_report_path = join(self.fold_results_dir,
                                      'train_report.csv')
        self.fold_stats_path = join(self.fold_results_dir, 'fold_stats.pickle')

        # if the user want to save the model on each epoch...
        if self.save_model_at_each_epoch:
            # ...creating models directory and an iterable name, else...
            self.h5_models_dir = join(self.fold_results_dir, 'h5_models')
            touch_dir(self.h5_models_dir)
            self.h5_model_path = join(self.h5_models_dir, 'net{epoch:02d}.h5')
        else:
            # ...pointing to the same results directory
            self.h5_model_path = join(self.fold_results_dir,
                                      'net_best_val_loss.h5')

    def train(self):
        """Write the model report, set up the callbacks and train the net."""
        # saving a model picture
        # TODO: model_pic.png saving routine

        # saving a model report
        with open(self.model_report_path, 'w') as mr:
            self.model.summary(print_fn=lambda x: mr.write(x + '\n'))

        # pre-allocating callbacks list
        callbacks = []

        # saving a train report
        csv = CSVLogger(self.train_report_path)
        callbacks.append(csv)

        # saving model each epoch
        if self.save_model_at_each_epoch:
            mcp = ModelCheckpoint(self.h5_model_path)
            callbacks.append(mcp)
        # else:
        #     mcp = ModelCheckpoint(self.h5_model_path,
        #                           monitor='val_loss',
        #                           save_best_only=True)
        #     callbacks.append(mcp)

        # if early_stopping is True...
        if self.early_stopping is True:
            # putting epochs to a very large number
            epochs = 1000

            # creating early stopping callback
            esc = EarlyStopping(monitor=self.monitor,
                                min_delta=self.min_delta,
                                patience=self.patience,
                                verbose=1)
            callbacks.append(esc)
        else:
            # getting user defined epochs value
            epochs = self.epochs

        # using fit_generator if a data generator is required
        if self.data_generator is True:
            training_generator = EEGDataGenerator(self.dataset.X_train,
                                                  self.dataset.y_train,
                                                  self.batch_size,
                                                  self.n_classes,
                                                  self.crop_sample_size,
                                                  self.crop_step)
            # NOTE(review): the validation generator is built from the
            # training arrays, so validation metrics would track training
            # data; confirm whether validation arrays were intended
            validation_generator = EEGDataGenerator(self.dataset.X_train,
                                                    self.dataset.y_train,
                                                    self.batch_size,
                                                    self.n_classes,
                                                    self.crop_sample_size,
                                                    self.crop_step)

            # training!
562 | print_manager( 563 | 'RUNNING TRAINING ON FOLD {}'.format(self.fold_idx + 1), 564 | 'double-dashed' 565 | ) 566 | self.model.fit_generator(generator=training_generator, 567 | validation_data=validation_generator, 568 | use_multiprocessing=True, 569 | workers=self.workers, 570 | epochs=epochs, 571 | verbose=1, 572 | callbacks=callbacks) 573 | else: 574 | # creating crops 575 | self.dataset.make_crops(self.crop_sample_size, self.crop_step) 576 | 577 | # forcing the x examples to have 4 dimensions 578 | self.dataset.add_axis() 579 | 580 | # parsing y to categorical 581 | self.dataset.to_categorical() 582 | 583 | # TODO: MetricsTracker for Data Generation routine 584 | # creating a MetricsTracker instance 585 | if self.metrics_tracker is None: 586 | callbacks.append( 587 | MetricsTracker( 588 | dataset=self.dataset, 589 | epochs=self.epochs, 590 | n_classes=self.n_classes, 591 | batch_size=self.batch_size, 592 | h5_model_path=self.h5_model_path, 593 | fold_stats_path=self.fold_stats_path 594 | ) 595 | ) 596 | else: 597 | callbacks.append(self.metrics_tracker) 598 | 599 | # training! 600 | print_manager( 601 | 'RUNNING TRAINING ON FOLD {}'.format(self.fold_idx + 1), 602 | 'double-dashed' 603 | ) 604 | self.model.fit(x=self.dataset.X_train, 605 | y=self.dataset.y_train, 606 | validation_data=(self.dataset.X_valid, 607 | self.dataset.y_valid), 608 | batch_size=self.batch_size, 609 | epochs=epochs, 610 | verbose=1, 611 | callbacks=callbacks, 612 | shuffle=self.shuffle) 613 | # TODO: if validation_frac is 0 or None, not to split train and test 614 | # to train the epochs hyperparameter. 
615 | 616 | def test(self): 617 | # TODO: evaluate_generator if data_generator is True 618 | # loading best net 619 | self.model.load_weights(self.h5_model_path) 620 | 621 | # computing loss and other metrics 622 | score = self.model.evaluate( 623 | self.dataset.X_test, 624 | self.dataset.y_test, 625 | verbose=1 626 | ) 627 | 628 | print('Test loss:', score[0]) 629 | print('Test acc:', score[1]) 630 | 631 | # making predictions on X_test with final model and getting also 632 | # y_test from memory; parsing both back from categorical 633 | y_pred = self.model.predict(self.dataset.X_test).argmax(axis=1) 634 | if self.data_generator is True: 635 | y_test = self.dataset.y_test 636 | else: 637 | y_test = self.dataset.y_test.argmax(axis=1) 638 | 639 | # computing confusion matrix 640 | conf_mtx = confusion_matrix(y_true=y_test, y_pred=y_pred) 641 | print("Confusion matrix:\n", conf_mtx) 642 | 643 | def prepare_for_transfer_learning(self, 644 | cross_subj_dir_path, 645 | subject_id, 646 | train_anyway=False): 647 | # printing the start 648 | print_manager('PREPARING FOR TRANSFER LEARNING', 'double-dashed') 649 | 650 | # getting this subject cross-subject dir 651 | cross_subj_this_subj_dir_path = join(cross_subj_dir_path, 652 | 'subj_cross', 653 | my_formatter(subject_id, 'fold')) 654 | 655 | # loading 656 | self.model.load_weights(join(cross_subj_this_subj_dir_path, 657 | 'net_best_val_loss.h5')) 658 | 659 | if train_anyway is False: 660 | # pre-saving this net as best one 661 | self.model.save(self.h5_model_path) 662 | 663 | # creating metrics tracker instance 664 | self.metrics_tracker = MetricsTracker( 665 | dataset=self.dataset, 666 | epochs=self.epochs, 667 | n_classes=self.n_classes, 668 | batch_size=self.batch_size, 669 | h5_model_path=self.h5_model_path, 670 | fold_stats_path=self.fold_stats_path 671 | ) 672 | 673 | # loading cross-subject info 674 | with open(join(cross_subj_this_subj_dir_path, 675 | 'fold_stats.pickle'), 'rb') as f: 676 | results = load(f) 677 | 
678 | # forcing best net to be the 0 one 679 | self.metrics_tracker.best['loss'] = results['test']['loss'] 680 | self.metrics_tracker.best['acc'] = results['test']['acc'] 681 | self.metrics_tracker.best['idx'] = 0 682 | 683 | # printing the end 684 | print_manager('DONE!!', print_style='last', bottom_return=1) 685 | 686 | def freeze_layers(self, layers_to_freeze): 687 | print_manager('FREEZING LAYERS', 'double-dashed') 688 | if layers_to_freeze == 0: 689 | print('NOTHING TO FREEZE!!') 690 | else: 691 | print("I'm gonna gonna freeze {} layers.".format(layers_to_freeze)) 692 | 693 | # freezing layers 694 | frozen = 0 695 | if layers_to_freeze > 0: 696 | idx = 0 697 | step = 1 698 | else: 699 | idx = -1 700 | step = -1 701 | layers_to_freeze = - layers_to_freeze 702 | while frozen < layers_to_freeze: 703 | layer = self.model.layers[idx] 704 | if layer.name[:4] == 'conv' or layer.name[:5] == 'dense': 705 | layer.trainable = False 706 | frozen += 1 707 | idx += step 708 | 709 | # creating optimizer instance 710 | if self.optimizer is 'Adam': 711 | opt = optimizers.Adam(lr=self.learning_rate) 712 | else: 713 | opt = optimizers.Adam(lr=self.learning_rate) 714 | 715 | # compiling model 716 | self.model.compile(loss=self.loss, 717 | optimizer=opt, 718 | metrics=['accuracy']) 719 | 720 | # printing model information 721 | self.model.summary() 722 | print_manager('DONE!!', print_style='last', bottom_return=1) 723 | -------------------------------------------------------------------------------- /hgdecode/fbcsprlda.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import logging as log 4 | from copy import deepcopy 5 | from numpy import empty, mean, array 6 | from hgdecode.lda import lda_apply 7 | from hgdecode.lda import lda_train_scaled 8 | from hgdecode.signalproc import bandpass_mne 9 | from hgdecode.signalproc import select_trials 10 | from hgdecode.signalproc import calculate_csp 11 | from 
hgdecode.signalproc import select_classes 12 | from hgdecode.signalproc import apply_csp_var_log 13 | from hgdecode.signalproc import concatenate_channels 14 | from braindecode.datautil.iterators import get_balanced_batches 15 | from braindecode.datautil.trial_segment import \ 16 | create_signal_target_from_raw_mne 17 | 18 | 19 | class BinaryFBCSP(object): 20 | """ 21 | # TODO: a description for this class 22 | """ 23 | 24 | def __init__(self, 25 | cnt, 26 | clean_trial_mask, 27 | filterbands, 28 | filt_order, 29 | folds, 30 | class_pairs, 31 | epoch_ival_ms, 32 | n_filters, 33 | marker_def, 34 | name_to_stop_codes=None, 35 | average_trial_covariance=False): 36 | # cnt and signal parameters 37 | self.cnt = cnt 38 | self.clean_trial_mask = clean_trial_mask 39 | self.epoch_ival_ms = epoch_ival_ms 40 | self.marker_def = marker_def 41 | self.name_to_stop_codes = name_to_stop_codes 42 | 43 | # filter bank parameters 44 | self.filterbands = filterbands 45 | self.filt_order = filt_order 46 | 47 | # machine learning parameters 48 | self.folds = folds 49 | self.n_filters = n_filters 50 | self.class_pairs = class_pairs 51 | self.average_trial_covariance = average_trial_covariance 52 | 53 | # getting result shape 54 | n_filterbands = len(self.filterbands) 55 | n_folds = len(self.folds) 56 | n_class_pairs = len(self.class_pairs) 57 | result_shape = (n_filterbands, n_folds, n_class_pairs) 58 | 59 | # creating result related properties and pre-allocating them 60 | self.filters = empty(result_shape, dtype=object) 61 | self.patterns = empty(result_shape, dtype=object) 62 | self.variances = empty(result_shape, dtype=object) 63 | self.train_feature = empty(result_shape, dtype=object) 64 | self.test_feature = empty(result_shape, dtype=object) 65 | self.train_feature_full_fold = empty(result_shape, dtype=object) 66 | self.test_feature_full_fold = empty(result_shape, dtype=object) 67 | self.clf = empty(result_shape, dtype=object) 68 | self.train_accuracy = empty(result_shape, 
dtype=object) 69 | self.test_accuracy = empty(result_shape, dtype=object) 70 | self.train_labels_full_fold = empty(len(self.folds), dtype=object) 71 | self.test_labels_full_fold = empty(len(self.folds), dtype=object) 72 | self.train_labels = empty( 73 | (len(self.folds), len(self.class_pairs)), 74 | dtype=object) 75 | self.test_labels = empty( 76 | (len(self.folds), len(self.class_pairs)), 77 | dtype=object) 78 | 79 | def run(self): 80 | # %% CYCLING ON FILTERS IN FILTERBANK 81 | # %% 82 | # just for me: enumerate is a really powerful built-in python 83 | # function that allows you to loop over something and have an 84 | # automatic counter. In this case, bp_nr is the counter, 85 | # then filt_band is the default exit for the method getitem for 86 | # filterbands class. 87 | for bp_nr, filt_band in enumerate(self.filterbands): 88 | # printing filter information 89 | self.print_filter(bp_nr) 90 | 91 | # bandpassing all the cnt RawArray with the current filter 92 | bandpassed_cnt = bandpass_mne( 93 | self.cnt, 94 | filt_band[0], 95 | filt_band[1], 96 | filt_order=self.filt_order 97 | ) 98 | 99 | # epoching: from cnt data to epoched data 100 | epo = create_signal_target_from_raw_mne( 101 | bandpassed_cnt, 102 | name_to_start_codes=self.marker_def, 103 | epoch_ival_ms=self.epoch_ival_ms, 104 | name_to_stop_codes=self.name_to_stop_codes 105 | ) 106 | 107 | # cleaning epoched data with clean_trial_mask (finally) 108 | if len(self.folds) != 1: 109 | epo.X = epo.X[self.clean_trial_mask] 110 | epo.y = epo.y[self.clean_trial_mask] 111 | 112 | # %% CYCLING ON FOLDS 113 | # %% 114 | for fold_nr in range(len(self.folds)): 115 | # printing fold information 116 | self.print_fold_nr(fold_nr) 117 | 118 | # getting information on current fold 119 | train_test = self.folds[fold_nr] 120 | 121 | # getting train and test indexes 122 | train_ind = train_test['train'] 123 | test_ind = train_test['test'] 124 | 125 | # getting train data from train indexes 126 | epo_train = 
select_trials(epo, train_ind) 127 | 128 | # getting test data from test indexes 129 | epo_test = select_trials(epo, test_ind) 130 | 131 | # logging info on train 132 | log.info("#Train trials: {:4d}".format(len(epo_train.X))) 133 | 134 | # logging info on test 135 | log.info("#Test trials : {:4d}".format(len(epo_test.X))) 136 | 137 | # setting train labels of this fold 138 | self.train_labels_full_fold[fold_nr] = epo_train.y 139 | 140 | # setting test labels of this fold 141 | self.test_labels_full_fold[fold_nr] = epo_test.y 142 | 143 | # %% CYCLING ON ALL POSSIBLE CLASS PAIRS 144 | # %% 145 | for pair_nr in range(len(self.class_pairs)): 146 | # getting class pair from index (pair_nr) 147 | class_pair = self.class_pairs[pair_nr] 148 | 149 | # printing class pair information 150 | self.print_class_pair(class_pair) 151 | 152 | # getting train trials only for current two classes 153 | epo_train_pair = select_classes(epo_train, class_pair) 154 | 155 | # getting test trials only for current two classes 156 | epo_test_pair = select_classes(epo_test, class_pair) 157 | 158 | # saving train labels for this two classes 159 | self.train_labels[fold_nr][pair_nr] = epo_train_pair.y 160 | 161 | # saving test labels for this two classes 162 | self.test_labels[fold_nr][pair_nr] = epo_test_pair.y 163 | 164 | # %% COMPUTING CSP 165 | # %% 166 | filters, patterns, variances = calculate_csp( 167 | epo_train_pair, 168 | average_trial_covariance=self.average_trial_covariance 169 | ) 170 | 171 | # %% FEATURE EXTRACTION 172 | # %% 173 | # choosing how many spacial filter to apply; 174 | # if no spacial filter number specified... 175 | if self.n_filters is None: 176 | # ...taking all columns, else... 177 | columns = list(range(len(filters))) 178 | else: 179 | # ...take topmost and bottommost filters; 180 | # e.g. 
for n_filters=3 we are going to pick: 181 | # 0, 1, 2, -3, -2, -1 182 | columns = (list(range(0, self.n_filters)) + 183 | list(range(-self.n_filters, 0))) 184 | 185 | # feature extraction on train 186 | train_feature = apply_csp_var_log( 187 | epo_train_pair, 188 | filters, 189 | columns 190 | ) 191 | 192 | # feature extraction on test 193 | test_feature = apply_csp_var_log( 194 | epo_test_pair, 195 | filters, 196 | columns 197 | ) 198 | 199 | # %% COMPUTING LDA USING TRAIN FEATURES 200 | # %% 201 | # clf is a 1x2 tuple where: 202 | # * clf[0] is hyperplane parameters 203 | # * clf[1] is hyperplane bias 204 | # with clf, you can recreate the n-dimensional 205 | # hyperplane that splits class space, so you can 206 | # classify your fbcsp extracted features. 207 | clf = lda_train_scaled(train_feature, shrink=True) 208 | 209 | # %% APPLYING LDA ON TRAIN 210 | # %% 211 | # applying LDA 212 | train_out = lda_apply(train_feature, clf) 213 | 214 | # getting true/false labels instead of class labels 215 | # for example, if you have: 216 | # train_feature.y --> [1, 3, 3, 1] 217 | # class_pair --> [1, 3] 218 | # so you will have: 219 | # true_0_1_labels_train = [False, True, True, False] 220 | true_0_1_labels_train = train_feature.y == class_pair[1] 221 | 222 | # if predicted output grater than 0 True, False instead 223 | predicted_train = train_out >= 0 224 | 225 | # computing accuracy 226 | # if mean has a boolean array as input, it will 227 | # compute number of True elements divided by total 228 | # number of elements, so the accuracy 229 | train_accuracy = mean( 230 | true_0_1_labels_train == predicted_train 231 | ) 232 | 233 | # %% APPLYING LDA ON TEST 234 | # %% 235 | # same procedure 236 | test_out = lda_apply(test_feature, clf) 237 | true_0_1_labels_test = test_feature.y == class_pair[1] 238 | predicted_test = test_out >= 0 239 | test_accuracy = mean( 240 | true_0_1_labels_test == predicted_test 241 | ) 242 | 243 | # %% FEATURE COMPUTATION FOR FULL FOLD 244 | # %% 
(FOR LATER MULTICLASS) 245 | # here we use csp computed only for this pair of classes 246 | # to compute feature for all the current fold 247 | # train here 248 | train_feature_full_fold = apply_csp_var_log( 249 | epo_train, 250 | filters, 251 | columns 252 | ) 253 | 254 | # test here 255 | test_feature_full_fold = apply_csp_var_log( 256 | epo_test, 257 | filters, 258 | columns 259 | ) 260 | 261 | # %% STORE RESULTS 262 | # %% 263 | # only store used patterns filters variances 264 | # to save memory space on disk 265 | self.store_results( 266 | bp_nr, 267 | fold_nr, 268 | pair_nr, 269 | filters[:, columns], 270 | patterns[:, columns], 271 | variances[columns], 272 | train_feature, 273 | test_feature, 274 | train_feature_full_fold, 275 | test_feature_full_fold, 276 | clf, 277 | train_accuracy, 278 | test_accuracy 279 | ) 280 | 281 | # printing the end of this super-nested cycle 282 | self.print_results(bp_nr, fold_nr, pair_nr) 283 | 284 | # printing a blank line to divide filters 285 | print() 286 | 287 | def store_results(self, 288 | bp_nr, 289 | fold_nr, 290 | pair_nr, 291 | filters, 292 | patterns, 293 | variances, 294 | train_feature, 295 | test_feature, 296 | train_feature_full_fold, 297 | test_feature_full_fold, 298 | clf, 299 | train_accuracy, 300 | test_accuracy): 301 | """ Store all supplied arguments to this objects dict, at the correct 302 | indices for filterband / fold / class_pair.""" 303 | local_vars = locals() 304 | del local_vars['self'] 305 | del local_vars['bp_nr'] 306 | del local_vars['fold_nr'] 307 | del local_vars['pair_nr'] 308 | for var in local_vars: 309 | self.__dict__[var][bp_nr, fold_nr, pair_nr] = local_vars[var] 310 | 311 | def print_filter(self, bp_nr): 312 | # distinguish filter blocks by empty line 313 | log.info( 314 | "Filter {:d}/{:d}, {:4.2f} to {:4.2f} Hz".format( 315 | bp_nr + 1, 316 | len(self.filterbands), 317 | *self.filterbands[bp_nr]) 318 | ) 319 | 320 | @staticmethod 321 | def print_fold_nr(fold_nr): 322 | log.info("Fold 
Nr: {:d}".format(fold_nr + 1)) 323 | 324 | @staticmethod 325 | def print_class_pair(class_pair): 326 | class_pair_plus_one = (array(class_pair) + 1).tolist() 327 | log.info("Class {:d} vs {:d}".format(*class_pair_plus_one)) 328 | 329 | def print_results(self, bp_nr, fold_nr, pair_nr): 330 | log.info("Train: {:4.2f}%".format( 331 | self.train_accuracy[bp_nr, fold_nr, pair_nr] * 100)) 332 | log.info("Test: {:4.2f}%".format( 333 | self.test_accuracy[bp_nr, fold_nr, pair_nr] * 100)) 334 | 335 | 336 | class FBCSP(object): 337 | """ 338 | # TODO: a description for this class 339 | """ 340 | 341 | def __init__(self, 342 | binary_csp, 343 | n_features=None, 344 | n_filterbands=None, 345 | forward_steps=2, 346 | backward_steps=1, 347 | stop_when_no_improvement=False): 348 | # copying inputs 349 | self.binary_csp = binary_csp 350 | self.n_features = n_features 351 | self.n_filterbands = n_filterbands 352 | self.forward_steps = forward_steps 353 | self.backward_steps = backward_steps 354 | self.stop_when_no_improvement = stop_when_no_improvement 355 | 356 | # pre-allocating other properties 357 | self.train_feature = None 358 | self.train_feature_full_fold = None 359 | self.test_feature = None 360 | self.test_feature_full_fold = None 361 | self.selected_filter_inds = None 362 | self.selected_filters_per_filterband = None 363 | self.selected_features = None 364 | self.clf = None 365 | self.train_accuracy = None 366 | self.test_accuracy = None 367 | self.train_pred_full_fold = None 368 | self.test_pred_full_fold = None 369 | 370 | def run(self): 371 | self.select_filterbands() 372 | if self.n_features is not None: 373 | log.info("Run feature selection...") 374 | self.collect_best_features() 375 | log.info("Done.") 376 | else: 377 | self.collect_features() 378 | self.train_classifiers() 379 | self.predict_outputs() 380 | 381 | def select_filterbands(self): 382 | n_all_filterbands = len(self.binary_csp.filterbands) 383 | if self.n_filterbands is None: 384 | 
self.selected_filter_inds = list(range(n_all_filterbands)) 385 | else: 386 | # Select the filterbands with the highest mean accuracy on the 387 | # training sets 388 | mean_accs = np.mean(self.binary_csp.train_accuracy, axis=(1, 2)) 389 | best_filters = np.argsort(mean_accs)[::-1][:self.n_filterbands] 390 | self.selected_filter_inds = best_filters 391 | 392 | def collect_features(self): 393 | n_folds = len(self.binary_csp.folds) 394 | n_class_pairs = len(self.binary_csp.class_pairs) 395 | result_shape = (n_folds, n_class_pairs) 396 | self.train_feature = np.empty(result_shape, dtype=object) 397 | self.train_feature_full_fold = np.empty(result_shape, dtype=object) 398 | self.test_feature = np.empty(result_shape, dtype=object) 399 | self.test_feature_full_fold = np.empty(result_shape, dtype=object) 400 | 401 | bcsp = self.binary_csp # just to make code shorter 402 | filter_inds = self.selected_filter_inds 403 | for fold_i in range(n_folds): 404 | for class_i in range(n_class_pairs): 405 | self.train_feature[fold_i, class_i] = concatenate_channels( 406 | bcsp.train_feature[filter_inds, fold_i, class_i]) 407 | self.train_feature_full_fold[fold_i, class_i] = ( 408 | concatenate_channels( 409 | bcsp.train_feature_full_fold[ 410 | filter_inds, fold_i, class_i])) 411 | self.test_feature[fold_i, class_i] = concatenate_channels( 412 | bcsp.test_feature[filter_inds, fold_i, class_i] 413 | ) 414 | self.test_feature_full_fold[fold_i, class_i] = ( 415 | concatenate_channels( 416 | bcsp.test_feature_full_fold[ 417 | filter_inds, fold_i, class_i] 418 | )) 419 | 420 | def collect_best_features(self): 421 | """ Selects features filterwise per filterband, starting with no 422 | features, then selecting the best filterpair from the bestfilterband 423 | (measured on internal train/test split)""" 424 | bincsp = self.binary_csp # just to make code shorter 425 | 426 | # getting dimension for feature arrays 427 | n_folds = len(self.binary_csp.folds) 428 | n_class_pairs = 
len(self.binary_csp.class_pairs) 429 | result_shape = (n_folds, n_class_pairs) 430 | 431 | # initializing feature array for this classes 432 | self.train_feature = np.empty(result_shape, dtype=object) 433 | self.train_feature_full_fold = np.empty(result_shape, dtype=object) 434 | self.test_feature = np.empty(result_shape, dtype=object) 435 | self.test_feature_full_fold = np.empty(result_shape, dtype=object) 436 | self.selected_filters_per_filterband = np.empty(result_shape, 437 | dtype=object) 438 | # outer cycle on folds 439 | for fold_i in range(n_folds): 440 | # inner cycle on pairs 441 | for class_pair_i in range(n_class_pairs): 442 | # saving bincsp features locally (prevent ram to modify values) 443 | bin_csp_train_features = deepcopy( 444 | bincsp.train_feature[ 445 | self.selected_filter_inds, fold_i, class_pair_i 446 | ] 447 | ) 448 | bin_csp_train_features_full_fold = deepcopy( 449 | bincsp.train_feature_full_fold[ 450 | self.selected_filter_inds, 451 | fold_i, class_pair_i 452 | ] 453 | ) 454 | bin_csp_test_features = deepcopy( 455 | bincsp.test_feature[ 456 | self.selected_filter_inds, 457 | fold_i, 458 | class_pair_i 459 | ] 460 | ) 461 | bin_csp_test_features_full_fold = deepcopy( 462 | bincsp.test_feature_full_fold[ 463 | self.selected_filter_inds, fold_i, class_pair_i 464 | ] 465 | ) 466 | 467 | # selecting best filters 468 | selected_filters_per_filt = \ 469 | self.select_best_filters_best_filterbands( 470 | bin_csp_train_features, 471 | max_features=self.n_features, 472 | forward_steps=self.forward_steps, 473 | backward_steps=self.backward_steps, 474 | stop_when_no_improvement=self.stop_when_no_improvement 475 | ) 476 | 477 | # collecting train features 478 | self.train_feature[fold_i, class_pair_i] = \ 479 | self.collect_features_for_filter_selection( 480 | bin_csp_train_features, 481 | selected_filters_per_filt 482 | ) 483 | 484 | # collecting train features full fold 485 | self.train_feature_full_fold[fold_i, class_pair_i] = \ 486 | 
self.collect_features_for_filter_selection( 487 | bin_csp_train_features_full_fold, 488 | selected_filters_per_filt 489 | ) 490 | 491 | # collecting test features 492 | self.test_feature[fold_i, class_pair_i] = \ 493 | self.collect_features_for_filter_selection( 494 | bin_csp_test_features, 495 | selected_filters_per_filt 496 | ) 497 | 498 | # collecting test features full fold 499 | self.test_feature_full_fold[fold_i, class_pair_i] = \ 500 | self.collect_features_for_filter_selection( 501 | bin_csp_test_features_full_fold, 502 | selected_filters_per_filt 503 | ) 504 | 505 | # saving also the filters selected for this fold and pair 506 | self.selected_filters_per_filterband[fold_i, class_pair_i] = \ 507 | selected_filters_per_filt 508 | 509 | @staticmethod 510 | def select_best_filters_best_filterbands(features, 511 | max_features, 512 | forward_steps, 513 | backward_steps, 514 | stop_when_no_improvement): 515 | n_filterbands = len(features) 516 | n_filters_per_fb = features[0].X.shape[1] / 2 517 | selected_filters_per_band = [0] * n_filterbands 518 | best_selected_filters_per_filterband = None 519 | last_best_accuracy = -1 520 | 521 | # Run until no improvement or max features reached 522 | selection_finished = False 523 | while not selection_finished: 524 | for _ in range(forward_steps): 525 | best_accuracy = -1 526 | 527 | # let's try always taking a feature in each iteration 528 | for filt_i in range(n_filterbands): 529 | this_filt_per_fb = deepcopy(selected_filters_per_band) 530 | if this_filt_per_fb[filt_i] == n_filters_per_fb: 531 | continue 532 | this_filt_per_fb[filt_i] = this_filt_per_fb[filt_i] + 1 533 | all_features = \ 534 | FBCSP.collect_features_for_filter_selection( 535 | features, 536 | this_filt_per_fb 537 | ) 538 | 539 | # make 5 times cross validation... 
540 | test_accuracy = FBCSP.cross_validate_lda( 541 | all_features 542 | ) 543 | 544 | if test_accuracy > best_accuracy: 545 | best_accuracy = test_accuracy 546 | best_selected_filters_per_filterband = this_filt_per_fb 547 | 548 | selected_filters_per_band = \ 549 | best_selected_filters_per_filterband 550 | 551 | for _ in range(backward_steps): 552 | best_accuracy = -1 553 | # let's try always taking a feature in each iteration 554 | for filt_i in range(n_filterbands): 555 | this_filt_per_fb = deepcopy(selected_filters_per_band) 556 | if this_filt_per_fb[filt_i] == 0: 557 | continue 558 | this_filt_per_fb[filt_i] = this_filt_per_fb[filt_i] - 1 559 | all_features = \ 560 | FBCSP.collect_features_for_filter_selection( 561 | features, 562 | this_filt_per_fb 563 | ) 564 | # make 5 times cross validation... 565 | test_accuracy = FBCSP.cross_validate_lda( 566 | all_features 567 | ) 568 | if test_accuracy > best_accuracy: 569 | best_accuracy = test_accuracy 570 | best_selected_filters_per_filterband = this_filt_per_fb 571 | selected_filters_per_band = \ 572 | best_selected_filters_per_filterband 573 | 574 | selection_finished = 2 * np.sum( 575 | selected_filters_per_band) >= max_features 576 | if stop_when_no_improvement: 577 | # there was no improvement if accuracy did not increase... 578 | selection_finished = (selection_finished 579 | or best_accuracy <= last_best_accuracy) 580 | last_best_accuracy = best_accuracy 581 | return selected_filters_per_band 582 | 583 | @staticmethod 584 | def collect_features_for_filter_selection( 585 | features, 586 | filters_for_filterband 587 | ): 588 | n_filters_per_fb = features[0].X.shape[1] // 2 589 | n_filterbands = len(features) 590 | # start with filters of first filterband... 
591 | # then add others all together 592 | first_features = deepcopy(features[0]) 593 | first_n_filters = filters_for_filterband[0] 594 | if first_n_filters == 0: 595 | first_features.X = first_features.X[:, 0:0] 596 | else: 597 | first_features.X = \ 598 | first_features.X[ 599 | :, 600 | list(range(first_n_filters)) + 601 | list(range(-first_n_filters, 0)) 602 | ] 603 | 604 | all_features = first_features 605 | for i in range(1, n_filterbands): 606 | this_n_filters = min(n_filters_per_fb, filters_for_filterband[i]) 607 | if this_n_filters > 0: 608 | next_features = deepcopy(features[i]) 609 | if this_n_filters == 0: 610 | next_features.X = next_features.X[0:0] 611 | else: 612 | next_features.X = \ 613 | next_features.X[ 614 | :, 615 | list(range(this_n_filters)) + 616 | list(range(-this_n_filters, 0)) 617 | ] 618 | all_features = concatenate_channels( 619 | (all_features, next_features)) 620 | return all_features 621 | 622 | @staticmethod 623 | def cross_validate_lda(features): 624 | n_trials = features.X.shape[0] 625 | folds = get_balanced_batches(n_trials, rng=None, shuffle=False, 626 | n_batches=5) 627 | # make to train-test splits, fold is test part.. 
628 | folds = [(np.setdiff1d(np.arange(n_trials), fold), 629 | fold) for fold in folds] 630 | test_accuracies = [] 631 | for train_inds, test_inds in folds: 632 | train_features = select_trials(features, train_inds) 633 | test_features = select_trials(features, test_inds) 634 | clf = lda_train_scaled(train_features, shrink=True) 635 | test_out = lda_apply(test_features, clf) 636 | 637 | higher_class = np.max(test_features.y) 638 | true_0_1_labels_test = test_features.y == higher_class 639 | 640 | predicted_test = test_out >= 0 641 | test_accuracy = np.mean(true_0_1_labels_test == predicted_test) 642 | test_accuracies.append(test_accuracy) 643 | return np.mean(test_accuracies) 644 | 645 | def train_classifiers(self): 646 | n_folds = len(self.binary_csp.folds) 647 | n_class_pairs = len(self.binary_csp.class_pairs) 648 | self.clf = np.empty((n_folds, n_class_pairs), 649 | dtype=object) 650 | for fold_i in range(n_folds): 651 | for class_i in range(n_class_pairs): 652 | train_feature = self.train_feature[fold_i, class_i] 653 | clf = lda_train_scaled(train_feature, shrink=True) 654 | self.clf[fold_i, class_i] = clf 655 | 656 | def predict_outputs(self): 657 | n_folds = len(self.binary_csp.folds) 658 | n_class_pairs = len(self.binary_csp.class_pairs) 659 | result_shape = (n_folds, n_class_pairs) 660 | self.train_accuracy = np.empty(result_shape, dtype=float) 661 | self.test_accuracy = np.empty(result_shape, dtype=float) 662 | self.train_pred_full_fold = np.empty(result_shape, dtype=object) 663 | self.test_pred_full_fold = np.empty(result_shape, dtype=object) 664 | for fold_i in range(n_folds): 665 | log.info("Fold Nr: {:d}".format(fold_i + 1)) 666 | for class_i, class_pair in enumerate(self.binary_csp.class_pairs): 667 | clf = self.clf[fold_i, class_i] 668 | class_pair_plus_one = (np.array(class_pair) + 1).tolist() 669 | log.info("Class {:d} vs {:d}".format(*class_pair_plus_one)) 670 | train_feature = self.train_feature[fold_i, class_i] 671 | train_out = 
lda_apply(train_feature, clf) 672 | true_0_1_labels_train = train_feature.y == class_pair[1] 673 | predicted_train = train_out >= 0 674 | # remove xarray wrapper with float( ... 675 | train_accuracy = float(np.mean(true_0_1_labels_train 676 | == predicted_train)) 677 | self.train_accuracy[fold_i, class_i] = train_accuracy 678 | 679 | test_feature = self.test_feature[fold_i, class_i] 680 | test_out = lda_apply(test_feature, clf) 681 | true_0_1_labels_test = test_feature.y == class_pair[1] 682 | predicted_test = test_out >= 0 683 | test_accuracy = float(np.mean(true_0_1_labels_test 684 | == predicted_test)) 685 | 686 | self.test_accuracy[fold_i, class_i] = test_accuracy 687 | 688 | train_feature_full_fold = self.train_feature_full_fold[fold_i, 689 | class_i] 690 | train_out_full_fold = lda_apply(train_feature_full_fold, clf) 691 | self.train_pred_full_fold[ 692 | fold_i, class_i] = train_out_full_fold 693 | test_feature_full_fold = self.test_feature_full_fold[fold_i, 694 | class_i] 695 | test_out_full_fold = lda_apply(test_feature_full_fold, clf) 696 | self.test_pred_full_fold[fold_i, class_i] = test_out_full_fold 697 | 698 | log.info("Train: {:4.2f}%".format(train_accuracy * 100)) 699 | log.info("Test: {:4.2f}%".format(test_accuracy * 100)) 700 | 701 | 702 | class MultiClassWeightedVoting(object): 703 | """ 704 | # TODO: a description for this class 705 | """ 706 | 707 | def __init__(self, 708 | train_labels, 709 | test_labels, 710 | train_preds, 711 | test_preds, 712 | class_pairs): 713 | # copying input parameters 714 | self.train_labels = train_labels 715 | self.test_labels = test_labels 716 | self.train_preds = train_preds 717 | self.test_preds = test_preds 718 | self.class_pairs = class_pairs 719 | 720 | # pre-allocating other useful properties 721 | self.train_class_sums = None 722 | self.test_class_sums = None 723 | self.train_predicted_labels = None 724 | self.test_predicted_labels = None 725 | self.train_accuracy = None 726 | self.test_accuracy = None 727 | 
728 | def run(self): 729 | # determine number of classes by number of unique classes 730 | # appearing in class pairs 731 | n_classes = len(np.unique(list(itertools.chain(*self.class_pairs)))) 732 | n_folds = len(self.train_labels) 733 | self.train_class_sums = np.empty(n_folds, dtype=object) 734 | self.test_class_sums = np.empty(n_folds, dtype=object) 735 | self.train_predicted_labels = np.empty(n_folds, dtype=object) 736 | self.test_predicted_labels = np.empty(n_folds, dtype=object) 737 | self.train_accuracy = np.ones(n_folds) * np.nan 738 | self.test_accuracy = np.ones(n_folds) * np.nan 739 | for fold_nr in range(n_folds): 740 | log.info("Fold Nr: {:d}".format(fold_nr + 1)) 741 | train_labels = self.train_labels[fold_nr] 742 | train_preds = self.train_preds[fold_nr] 743 | train_class_sums = np.zeros((len(train_labels), n_classes)) 744 | 745 | test_labels = self.test_labels[fold_nr] 746 | test_preds = self.test_preds[fold_nr] 747 | test_class_sums = np.zeros((len(test_labels), n_classes)) 748 | for pair_i, class_pair in enumerate(self.class_pairs): 749 | this_train_preds = train_preds[pair_i] 750 | assert len(this_train_preds) == len(train_labels) 751 | train_class_sums[:, class_pair[0]] -= this_train_preds 752 | train_class_sums[:, class_pair[1]] += this_train_preds 753 | this_test_preds = test_preds[pair_i] 754 | assert len(this_test_preds) == len(test_labels) 755 | test_class_sums[:, class_pair[0]] -= this_test_preds 756 | test_class_sums[:, class_pair[1]] += this_test_preds 757 | 758 | self.train_class_sums[fold_nr] = train_class_sums 759 | self.test_class_sums[fold_nr] = test_class_sums 760 | train_predicted_labels = np.argmax(train_class_sums, axis=1) 761 | test_predicted_labels = np.argmax(test_class_sums, axis=1) 762 | self.train_predicted_labels[fold_nr] = train_predicted_labels 763 | self.test_predicted_labels[fold_nr] = test_predicted_labels 764 | train_accuracy = (np.sum(train_predicted_labels == train_labels) / 765 | float(len(train_labels))) 766 | 
self.train_accuracy[fold_nr] = train_accuracy 767 | test_accuracy = (np.sum(test_predicted_labels == test_labels) / 768 | float(len(test_labels))) 769 | self.test_accuracy[fold_nr] = test_accuracy 770 | log.info("Train: {:4.2f}%".format(train_accuracy * 100)) 771 | log.info("Test: {:4.2f}%".format(test_accuracy * 100)) 772 | -------------------------------------------------------------------------------- /hgdecode/lda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.covariance import LedoitWolf 3 | 4 | 5 | def lda_train_scaled(fv, shrink=False): 6 | """Train the LDA classifier. 7 | 8 | Parameters 9 | ---------- 10 | fv : ``Data`` object 11 | the feature vector must have 2 dimensional data, the first 12 | dimension being the class axis. The unique class labels must be 13 | 0 and 1 otherwise a ``ValueError`` will be raised. 14 | shrink : Boolean, optional 15 | use shrinkage 16 | 17 | Returns 18 | ------- 19 | w : 1d array 20 | b : float 21 | 22 | Raises 23 | ------ 24 | ValueError : if the class labels are not exactly 0s and 1s 25 | 26 | Examples 27 | -------- 28 | 29 | >> clf = lda_train(fv_train) 30 | >> out = lda_apply(fv_test, clf) 31 | 32 | See Also 33 | -------- 34 | lda_apply 35 | 36 | """ 37 | assert shrink is True 38 | assert fv.X.ndim == 2 39 | x = fv.X 40 | y = fv.y 41 | if len(np.unique(y)) != 2: 42 | raise ValueError( 43 | 'Should only have two unique class labels, instead got' 44 | ': {labels}'.format(labels=np.unique(y))) 45 | # Use sorted labels 46 | labels = np.sort(np.unique(y)) 47 | mu1 = np.mean(x[y == labels[0]], axis=0) 48 | mu2 = np.mean(x[y == labels[1]], axis=0) 49 | # x' = x - m 50 | m = np.empty(x.shape) 51 | m[y == labels[0]] = mu1 52 | m[y == labels[1]] = mu2 53 | x2 = x - m 54 | # w = cov(x)^-1(mu2 - mu1) 55 | if shrink: 56 | estimator = LedoitWolf() 57 | covm = estimator.fit(x2).covariance_ 58 | else: 59 | covm = np.cov(x2.T) 60 | w = 
np.dot(np.linalg.pinv(covm), (mu2 - mu1)) 61 | 62 | # From MATLAB bbci toolbox: 63 | # https://github.com/bbci/bbci_public/blob/ 64 | # fe6caeb549fdc864a5accf76ce71dd2a926ff12b/classification/ 65 | # train_RLDAshrink.m#L133-L134 66 | # C.w= C.w/(C.w'*diff(C_mean, 1, 2))*2; 67 | # C.b= -C.w' * mean(C_mean,2); 68 | w = (w / np.dot(w.T, (mu2 - mu1))) * 2 69 | b = np.dot(-w.T, np.mean((mu1, mu2), axis=0)) 70 | assert not np.any(np.isnan(w)) 71 | assert not np.isnan(b) 72 | return w, b 73 | 74 | 75 | def lda_apply(fv, clf): 76 | """Apply feature vector to LDA classifier. 77 | 78 | Parameters 79 | ---------- 80 | fv : ``Data`` object 81 | the feature vector must have a 2 dimensional data, the first 82 | dimension being the class axis. 83 | clf : (1d array, float) 84 | 85 | Returns 86 | ------- 87 | 88 | out : 1d array 89 | The projection of the data on the hyperplane. 90 | 91 | Examples 92 | -------- 93 | 94 | >> clf = lda_train(fv_train) 95 | >> out = lda_apply(fv_test, clf) 96 | 97 | 98 | See Also 99 | -------- 100 | lda_train 101 | 102 | """ 103 | x = fv.X 104 | w, b = clf 105 | return np.dot(x, w) + b 106 | -------------------------------------------------------------------------------- /hgdecode/loaders.py: -------------------------------------------------------------------------------- 1 | import logging as log 2 | from copy import deepcopy 3 | from numpy import max 4 | from numpy import abs 5 | from numpy import sum 6 | from numpy import std 7 | from numpy import mean 8 | from numpy import array 9 | from numpy import floor 10 | from numpy import repeat 11 | from numpy import arange 12 | from numpy import setdiff1d 13 | from numpy import concatenate 14 | from numpy import count_nonzero 15 | from numpy.random import RandomState 16 | from os.path import join 17 | from mne.io.array.array import RawArray 18 | from hgdecode.utils import print_manager 19 | from hgdecode.classes import CrossValidation 20 | from sklearn.model_selection import StratifiedKFold 21 | from 
braindecode.datasets.bbci import BBCIDataset 22 | from braindecode.mne_ext.signalproc import mne_apply 23 | from braindecode.mne_ext.signalproc import resample_cnt 24 | from braindecode.mne_ext.signalproc import concatenate_raws_with_events 25 | from braindecode.datautil.signalproc import bandpass_cnt 26 | from braindecode.datautil.signalproc import exponential_running_standardize 27 | from braindecode.datautil.signal_target import SignalAndTarget 28 | from braindecode.datautil.trial_segment import \ 29 | create_signal_target_from_raw_mne 30 | 31 | 32 | # TODO: re-implement all this functions as an unique class 33 | 34 | 35 | def get_data_files_paths(data_dir, subject_id=1, train_test_split=True): 36 | # compute file name (for both train and test path) 37 | file_name = '{:d}.mat'.format(subject_id) 38 | 39 | # compute file paths 40 | if train_test_split: 41 | train_file_path = join(data_dir, 'train', file_name) 42 | test_file_path = join(data_dir, 'test', file_name) 43 | file_path = [train_file_path, test_file_path] 44 | else: 45 | file_path = [join(data_dir, 'train', file_name)] 46 | 47 | # return paths 48 | return file_path 49 | 50 | 51 | def load_cnt(file_path, channel_names, clean_on_all_channels=True): 52 | # if we have to run the cleaning procedure on all channels, putting 53 | # load_sensor_names to None will assure us the BBCIDataset class will 54 | # load all possible sensors 55 | if clean_on_all_channels is True: 56 | channel_names = None 57 | 58 | # create the loader object for BBCI standard 59 | loader = BBCIDataset(file_path, load_sensor_names=channel_names) 60 | 61 | # load data 62 | return loader.load() 63 | 64 | 65 | def get_clean_trial_mask(cnt, name_to_start_codes, clean_ival_ms=(0, 4000)): 66 | """ 67 | Scan trial in continuous data and create a mask with only the 68 | valid ones; in this way, at the and of the loading routine, 69 | after all the data pre-processing, you will be able to cut away 70 | the original not valid data. 
71 | """ 72 | # split cnt into trials data for cleaning 73 | set_for_cleaning = create_signal_target_from_raw_mne( 74 | cnt, 75 | name_to_start_codes, 76 | clean_ival_ms 77 | ) 78 | 79 | # compute the clean_trial_mask: in this case we take only all 80 | # trials that have absolute microvolt values larger than +- 800 81 | clean_trial_mask = max(abs(set_for_cleaning.X), axis=(1, 2)) < 800 82 | 83 | # logging clean trials information 84 | log.info( 85 | 'Clean trials: {:3d} of {:3d} ({:5.1f}%)'.format( 86 | sum(clean_trial_mask), 87 | len(set_for_cleaning.X), 88 | mean(clean_trial_mask) * 100) 89 | ) 90 | 91 | # return the clean_trial_mask 92 | return clean_trial_mask 93 | 94 | 95 | def pick_right_channels(cnt, channel_names): 96 | # return the same cnt but with only right channels 97 | return cnt.pick_channels(channel_names) 98 | 99 | 100 | def standardize_cnt(cnt, standardize_mode=0): 101 | # computing frequencies 102 | sampling_freq = cnt.info['sfreq'] 103 | init_freq = 0.1 104 | stop_freq = sampling_freq / 2 - 0.1 105 | filt_order = 3 106 | axis = 0 107 | filtfilt = False 108 | 109 | # filtering DC and frequencies higher than the nyquist one 110 | cnt = mne_apply( 111 | lambda x: 112 | bandpass_cnt( 113 | data=x, 114 | low_cut_hz=init_freq, 115 | high_cut_hz=stop_freq, 116 | fs=sampling_freq, 117 | filt_order=filt_order, 118 | axis=axis, 119 | filtfilt=filtfilt 120 | ), 121 | cnt 122 | ) 123 | 124 | # removing mean and normalizing in 3 different ways 125 | if standardize_mode == 0: 126 | # x - mean 127 | cnt = mne_apply( 128 | lambda x: 129 | x - mean(x, axis=0, keepdims=True), 130 | cnt 131 | ) 132 | elif standardize_mode == 1: 133 | # (x - mean) / std 134 | cnt = mne_apply( 135 | lambda x: 136 | (x - mean(x, axis=0, keepdims=True)) / 137 | std(x, axis=0, keepdims=True), 138 | cnt 139 | ) 140 | elif standardize_mode == 2: 141 | # parsing to milli volt for numerical stability of next operations 142 | cnt = mne_apply(lambda a: a * 1e6, cnt) 143 | 144 | # applying 
exponential_running_standardize (Schirrmeister) 145 | cnt = mne_apply( 146 | lambda x: 147 | exponential_running_standardize( 148 | x.T, 149 | factor_new=1e-3, 150 | init_block_size=1000, 151 | eps=1e-4 152 | ).T, 153 | cnt 154 | ) 155 | return cnt 156 | 157 | 158 | def load_and_preprocess_data(data_dir, 159 | name_to_start_codes, 160 | channel_names, 161 | subject_id=1, 162 | resampling_freq=None, 163 | clean_ival_ms=(0, 4000), 164 | train_test_split=True, 165 | clean_on_all_channels=True, 166 | standardize_mode=None): 167 | # TODO: create here another get_data_files_paths function if you have a 168 | # different file configuration; in every case, file_paths must be a 169 | # list of paths to valid BBCI standard files 170 | # getting data paths 171 | file_paths = get_data_files_paths( 172 | data_dir, 173 | subject_id=subject_id, 174 | train_test_split=train_test_split 175 | ) 176 | 177 | # starting the loading routine 178 | print_manager('DATA LOADING ROUTINE FOR SUBJ ' + str(subject_id), 179 | 'double-dashed') 180 | print_manager('Loading continuous data...') 181 | 182 | # pre-allocating main cnt 183 | cnt = None 184 | 185 | # loading files and merging them 186 | for idx, current_path in enumerate(file_paths): 187 | current_cnt = load_cnt(file_path=current_path, 188 | channel_names=channel_names, 189 | clean_on_all_channels=clean_on_all_channels) 190 | # if the path is the first one... 191 | if idx is 0: 192 | # ...copying current_cnt as the main one, else... 
193 | cnt = deepcopy(current_cnt) 194 | else: 195 | # merging current_cnt with the main one 196 | cnt = concatenate_raws_with_events([cnt, current_cnt]) 197 | print_manager('DONE!!', bottom_return=1) 198 | 199 | # getting clean_trial_mask 200 | print_manager('Getting clean trial mask...') 201 | clean_trial_mask = get_clean_trial_mask( 202 | cnt=cnt, 203 | name_to_start_codes=name_to_start_codes, 204 | clean_ival_ms=clean_ival_ms 205 | ) 206 | print_manager('DONE!!', bottom_return=1) 207 | 208 | # pick only right channels 209 | log.info('Picking only right channels...') 210 | cnt = pick_right_channels(cnt, channel_names) 211 | print_manager('DONE!!', bottom_return=1) 212 | 213 | # resample continuous data 214 | if resampling_freq is not None: 215 | log.info('Resampling continuous data...') 216 | cnt = resample_cnt( 217 | cnt, 218 | resampling_freq 219 | ) 220 | print_manager('DONE!!', bottom_return=1) 221 | 222 | # standardize continuous data 223 | if standardize_mode is not None: 224 | log.info('Standardizing continuous data...') 225 | log.info('Standardize mode: {}'.format(standardize_mode)) 226 | cnt = standardize_cnt(cnt=cnt, standardize_mode=standardize_mode) 227 | print_manager('DONE!!', 'last', bottom_return=1) 228 | 229 | return cnt, clean_trial_mask 230 | 231 | 232 | def ml_loader(data_dir, 233 | name_to_start_codes, 234 | channel_names, 235 | subject_id=1, 236 | resampling_freq=None, 237 | clean_ival_ms=(0, 4000), 238 | train_test_split=True, 239 | clean_on_all_channels=True, 240 | standardize_mode=None): 241 | outputs = load_and_preprocess_data( 242 | data_dir=data_dir, 243 | name_to_start_codes=name_to_start_codes, 244 | channel_names=channel_names, 245 | subject_id=subject_id, 246 | resampling_freq=resampling_freq, 247 | clean_ival_ms=clean_ival_ms, 248 | train_test_split=train_test_split, 249 | clean_on_all_channels=clean_on_all_channels, 250 | standardize_mode=standardize_mode 251 | ) 252 | return outputs[0], outputs[1] 253 | 254 | 255 | def 
dl_loader(data_dir, 256 | name_to_start_codes, 257 | channel_names, 258 | subject_id=1, 259 | resampling_freq=None, 260 | clean_ival_ms=(0, 4000), 261 | epoch_ival_ms=(-500, 4000), 262 | train_test_split=True, 263 | clean_on_all_channels=True, 264 | standardize_mode=0): 265 | # loading and pre-processing data 266 | cnt, clean_trial_mask = load_and_preprocess_data( 267 | data_dir=data_dir, 268 | name_to_start_codes=name_to_start_codes, 269 | channel_names=channel_names, 270 | subject_id=subject_id, 271 | resampling_freq=resampling_freq, 272 | clean_ival_ms=clean_ival_ms, 273 | train_test_split=train_test_split, 274 | clean_on_all_channels=clean_on_all_channels, 275 | standardize_mode=standardize_mode 276 | ) 277 | print_manager('EPOCHING AND CLEANING WITH MASK', 'double-dashed') 278 | 279 | # epoching continuous data (from RawArray to SignalAndTarget) 280 | print_manager('Epoching...') 281 | epo = create_signal_target_from_raw_mne( 282 | cnt, 283 | name_to_start_codes, 284 | epoch_ival_ms 285 | ) 286 | print_manager('DONE!!', bottom_return=1) 287 | 288 | # cleaning epoched signal with mask 289 | print_manager('cleaning with mask...') 290 | epo.X = epo.X[clean_trial_mask] 291 | epo.y = epo.y[clean_trial_mask] 292 | print_manager('DONE!!', 'last', bottom_return=1) 293 | 294 | # returning only the epoched signal 295 | return epo 296 | 297 | 298 | class CrossSubject(object): 299 | """ 300 | Nel momento in cui si crea una istanza di questa classe, essa caricherà 301 | tutti quanti gli id soggetto indicati in formato cnt con le relative 302 | maschere ed un nuovo array che ci dice tutti gli indici in cui iniziano 303 | i vari soggetti; inoltre, si andrà poi a specificare volta per volta il 304 | tipo di dato che si vorrà avere in memoria utilizzando un metodo parser, 305 | che andrà a sovrascrivere i dati nel nuovo formato. 
306 | """ 307 | 308 | def __init__(self, 309 | data_dir, 310 | subject_ids, 311 | channel_names, 312 | name_to_start_codes, 313 | random_state=None, 314 | validation_frac=None, 315 | validation_size=None, 316 | resampling_freq=None, 317 | train_test_split=True, 318 | clean_ival_ms=(-500, 4000), 319 | epoch_ival_ms=(-500, 4000), 320 | clean_on_all_channels=True): 321 | # from input properties 322 | self.data_dir = data_dir 323 | self.subject_ids = subject_ids 324 | self.channel_names = channel_names 325 | self.name_to_start_codes = name_to_start_codes 326 | self.resampling_freq = resampling_freq 327 | self.train_test_split = train_test_split 328 | self.clean_ival_ms = clean_ival_ms 329 | self.epoch_ival_ms = epoch_ival_ms 330 | self.clean_on_all_channels = clean_on_all_channels 331 | 332 | # saving random state; if it is not specified, creating a 1234 one 333 | if random_state is None: 334 | self.random_state = RandomState(1234) 335 | else: 336 | self.random_state = random_state 337 | 338 | # other object properties 339 | self.data = None 340 | self.clean_trial_mask = [] 341 | self.subject_labels = None 342 | self.subject_indexes = [] 343 | self.folds = None 344 | 345 | # for fold specific data, creating a blank property 346 | self.fold_data = None 347 | self.fold_subject_labels = None 348 | 349 | # loading the first subject (to pre-allocate cnt array) 350 | # we are gonna pass standardize_mode=None, so the loading procedure 351 | # will not standardize data. 
They will be standardized at the very 352 | # end, when all data are loaded 353 | temp_cnt, temp_mask = load_and_preprocess_data( 354 | data_dir=self.data_dir, 355 | name_to_start_codes=self.name_to_start_codes, 356 | channel_names=self.channel_names, 357 | subject_id=self.subject_ids[0], 358 | resampling_freq=self.resampling_freq, 359 | clean_ival_ms=self.clean_ival_ms, 360 | train_test_split=self.train_test_split, 361 | clean_on_all_channels=self.clean_on_all_channels, 362 | standardize_mode=None 363 | ) 364 | 365 | # allocate the first subject_labels 366 | temp_labels = repeat(array([subject_ids[0]]), len(temp_mask)) 367 | 368 | # appending new indexes (only cleaned cnt will count!) 369 | last_non_zero_len = count_nonzero(self.clean_trial_mask) 370 | self.subject_indexes.append( 371 | [last_non_zero_len, last_non_zero_len + count_nonzero(temp_mask)] 372 | ) 373 | 374 | # merging cnt, mask and labels (in this case assigning) 375 | self.data = temp_cnt 376 | self.clean_trial_mask = temp_mask 377 | self.subject_labels = temp_labels 378 | 379 | # creating iterable object from subject_ids and skipping the first one 380 | iter_subjects = iter(self.subject_ids) 381 | next(iter_subjects) 382 | 383 | # loading all others cnt data and concatenating them 384 | for current_subject in iter_subjects: 385 | # loading current subject cnt and mask 386 | temp_cnt, temp_mask = load_and_preprocess_data( 387 | subject_id=current_subject, # here the current subject! 
388 | data_dir=self.data_dir, 389 | name_to_start_codes=self.name_to_start_codes, 390 | channel_names=self.channel_names, 391 | resampling_freq=self.resampling_freq, 392 | clean_ival_ms=self.clean_ival_ms, 393 | train_test_split=self.train_test_split, 394 | clean_on_all_channels=self.clean_on_all_channels, 395 | standardize_mode=None 396 | ) 397 | 398 | # create the subject_labels for this subject 399 | temp_labels = repeat(array([current_subject]), len(temp_mask)) 400 | 401 | # appending new indexes (only cleaned cnt will count!) 402 | last_non_zero_len = count_nonzero(self.clean_trial_mask) 403 | self.subject_indexes.append( 404 | [last_non_zero_len, 405 | last_non_zero_len + count_nonzero(temp_mask)] 406 | ) 407 | 408 | # merging cnt and mask 409 | self.data = concatenate_raws_with_events([self.data, temp_cnt]) 410 | self.clean_trial_mask = \ 411 | concatenate([self.clean_trial_mask, temp_mask]) 412 | self.subject_labels = \ 413 | concatenate([self.subject_labels, temp_labels]) 414 | 415 | # computing validation_frac and validation_size 416 | if validation_size is None: 417 | if validation_frac is None: 418 | self.validation_frac = 0 419 | self.validation_size = 0 420 | else: 421 | self.validation_frac = validation_frac 422 | self.validation_size = \ 423 | int(floor(self.n_trials * self.validation_frac)) 424 | else: 425 | self.validation_size = validation_size 426 | self.validation_frac = self.validation_size / self.n_trials 427 | 428 | def parser(self, output_format, leave_subj=None, parsing_type=0): 429 | """ 430 | HOW DOES IT WORK? 
431 | ----------------- 432 | if parsing_type is 0 then epoched signal will be saved in 433 | fold_data; if parsing_type is 1 then the cnt signal will be replaced 434 | with the epoched one 435 | """ 436 | if output_format is 'epo': 437 | self.cnt_to_epo(parsing_type=parsing_type) 438 | elif output_format is 'EEGDataset': 439 | self.cnt_to_epo(parsing_type=parsing_type) 440 | self.epo_to_dataset(leave_subj=leave_subj, 441 | parsing_type=parsing_type) 442 | 443 | def cnt_to_epo(self, parsing_type): 444 | # checking if data is cnt; if not, the method will not work 445 | if isinstance(self.data, RawArray): 446 | """ 447 | WHATS GOING ON HERE? 448 | -------------------- 449 | If parsing_type is 0, then there will be a 'soft parsing 450 | routine', data will parsed and stored in fold_data instead of 451 | in the main data property 452 | """ 453 | if parsing_type == 0: 454 | # parsing from cnt to epoch 455 | print_manager('Parsing cnt signal to epoched one...') 456 | self.fold_data = create_signal_target_from_raw_mne( 457 | self.data, 458 | self.name_to_start_codes, 459 | self.epoch_ival_ms 460 | ) 461 | print_manager('DONE!!', bottom_return=1) 462 | 463 | # cleaning signal and labels with mask 464 | print_manager('Cleaning epoched signal with mask...') 465 | self.fold_data.X = self.fold_data.X[self.clean_trial_mask] 466 | self.fold_data.y = self.fold_data.y[self.clean_trial_mask] 467 | self.fold_subject_labels = \ 468 | self.subject_labels[self.clean_trial_mask] 469 | print_manager('DONE!!', bottom_return=1) 470 | elif parsing_type == 1: 471 | """ 472 | WHATS GOING ON HERE? 
473 | -------------------- 474 | If parsing_type is 1, then the epoched signal will replace 475 | the original one in the data property 476 | """ 477 | print_manager('Parsing cnt signal to epoched one...') 478 | self.data = create_signal_target_from_raw_mne( 479 | self.data, 480 | self.name_to_start_codes, 481 | self.epoch_ival_ms 482 | ) 483 | print_manager('DONE!!', bottom_return=1) 484 | 485 | # cleaning signal and labels 486 | print_manager('Cleaning epoched signal with mask...') 487 | self.data.X = self.data.X[self.clean_trial_mask] 488 | self.data.y = self.data.y[self.clean_trial_mask] 489 | self.subject_labels = \ 490 | self.subject_labels[self.clean_trial_mask] 491 | print_manager('DONE!!', bottom_return=1) 492 | else: 493 | raise ValueError( 494 | 'parsing_type {} not supported.'.format(parsing_type) 495 | ) 496 | 497 | # now that we have an epoched signal, we can already create 498 | # folds for cross-subject validation 499 | self.create_balanced_folds() 500 | 501 | def epo_to_dataset(self, leave_subj, parsing_type=0): 502 | print_manager('FOLD ALL BUT ' + str(leave_subj), 'double-dashed') 503 | print_manager('Creating current fold...') 504 | 505 | print_manager('DONE!!', bottom_return=1) 506 | print_manager('Parsing epoched signal to EEGDataset...') 507 | if parsing_type is 0: 508 | self.fold_data = CrossValidation.create_dataset_static( 509 | self.fold_data, self.folds[leave_subj - 1] 510 | ) 511 | elif parsing_type is 1: 512 | self.fold_data = CrossValidation.create_dataset_static( 513 | self.data, self.folds[leave_subj - 1] 514 | ) 515 | else: 516 | raise ValueError( 517 | 'parsing_type {} not supported.'.format(parsing_type) 518 | ) 519 | print_manager('DONE!!', bottom_return=1) 520 | print_manager('We obtained a ' + str(self.fold_data)) 521 | print_manager('DATA READY!!', 'last', bottom_return=1) 522 | 523 | def create_balanced_folds(self): 524 | # pre-allocating folds 525 | self.folds = [] 526 | for subj_idx, subj_idxs in 
enumerate(self.subject_indexes): 527 | # getting current test_idxs (all a subject trials) 528 | test_idxs = arange(subj_idxs[0], subj_idxs[1]) 529 | 530 | # getting train_idxs as all but the current subject 531 | train_idxs = setdiff1d(arange(self.n_trials), test_idxs) 532 | 533 | # pre-allocating valid_idxs 534 | valid_idxs = array([], dtype='int') 535 | 536 | # if no validation set is required... 537 | if self.validation_frac == 0: 538 | # setting valid_idxs to None, else... 539 | valid_idxs = None 540 | else: 541 | # ...determining number of splits for this train/validation set 542 | n_splits = int(floor(self.validation_frac * 100)) 543 | 544 | # getting StratifiesKFold object 545 | skf = StratifiedKFold(n_splits=n_splits, 546 | random_state=self.random_state, 547 | shuffle=True) 548 | 549 | # cycling on subject in the train fold 550 | for c_subj_idx, c_subj_idxs in enumerate(self.subject_indexes): 551 | if c_subj_idx == subj_idx: 552 | # nothing to do 553 | pass 554 | else: 555 | # splitting first subject train / valid 556 | X, y = self._get_subject_data(c_subj_idx) 557 | 558 | # get batch from StratifiedKFold object 559 | for c_train_idxs, c_valid_idxs in skf.split(X=X, y=y): 560 | # referring c_train_idxs and c_valid_idxs 561 | c_train_idxs += c_subj_idxs[0] 562 | c_valid_idxs += c_subj_idxs[0] 563 | 564 | # remove this batch indexes from train_idxs 565 | train_idxs = setdiff1d(train_idxs, c_valid_idxs) 566 | 567 | # adding this batch indexes to valid_idxs 568 | valid_idxs = concatenate([valid_idxs, 569 | c_valid_idxs]) 570 | # all is done for this subject!! 
Breaking cycle 571 | break 572 | 573 | # appending new fold 574 | self.folds.append( 575 | { 576 | 'train': train_idxs, 577 | 'valid': valid_idxs, 578 | 'test': test_idxs 579 | } 580 | ) 581 | 582 | def _get_subject_data(self, subj_idx): 583 | init = self.subject_indexes[subj_idx][0] 584 | stop = self.subject_indexes[subj_idx][1] 585 | ival = arange(init, stop) 586 | if isinstance(self.fold_data, SignalAndTarget): 587 | return self.fold_data.X[ival], self.fold_data.y[ival] 588 | elif isinstance(self.data, SignalAndTarget): 589 | return self.data.X[ival], self.data.y[ival] 590 | else: 591 | raise ValueError('You are trying to get epoched data but you ' 592 | 'still have to parse cnt data.') 593 | 594 | @property 595 | def n_trials(self): 596 | return count_nonzero(self.clean_trial_mask) 597 | -------------------------------------------------------------------------------- /hgdecode/models.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.models import Model 3 | from keras.layers import Dense 4 | from keras.layers import Input 5 | from keras.layers import Conv2D 6 | from keras.layers import Permute 7 | from keras.layers import Dropout 8 | from keras.layers import Flatten 9 | from keras.layers import Activation 10 | from keras.layers import MaxPooling2D 11 | from keras.layers import SeparableConv2D 12 | from keras.layers import DepthwiseConv2D 13 | from keras.layers import SpatialDropout2D 14 | from keras.layers import AveragePooling2D 15 | from keras.layers import BatchNormalization 16 | from keras.constraints import max_norm 17 | from keras.regularizers import l1_l2 18 | 19 | 20 | # TODO: define pool_size, strides and other parameters as experiment 21 | # properties, tunable from the user in the main script; then passing the 22 | # entire experiment to the model function constructor or each useful 23 | # parameter individually 24 | 25 | 26 | # %% DEEP CONV NET 27 | def 
DeepConvNet(n_classes=4, 28 | n_channels=64, 29 | n_samples=256, 30 | dropout_rate=0.5): 31 | """ Keras implementation of the Deep Convolutional Network as described in 32 | Schirrmeister et. al. (2017), Human Brain Mapping. 33 | Layer sizes follow the original paper: temporal kernels (1, 10) and 34 | max-pooling (1, 3) with stride (1, 3). NOTE(review): assuming Keras 35 | channels_first and the default 'valid' padding, the default n_samples=256 36 | is too short for four such blocks (time axis 256->82->24->5 < 10). 37 | Note that we use the max_norm constraint on all convolutional layers, as 38 | well as the classification layer. We also change the defaults for the 39 | BatchNormalization layer. We used this based on a personal communication 40 | with the original authors. 41 | Layer sizes used below (same as the original paper): 42 | pool_size 1, 3 43 | strides 1, 3 44 | conv filters 1, 10 45 | Note that this implementation has not been verified by the original 46 | authors. 47 | """ 48 | 49 | # start the model 50 | input_main = Input((1, n_channels, n_samples)) 51 | block1 = Conv2D(25, (1, 10), 52 | input_shape=(1, n_channels, n_samples), 53 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(input_main) 54 | block1 = Conv2D(25, (n_channels, 1), 55 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1) 56 | block1 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block1)  # NOTE(review): Keras momentum weights the OLD running stats; PyTorch momentum=0.1 corresponds to Keras 0.9 - confirm 0.1 is intended (same below) 57 | block1 = Activation('elu')(block1) 58 | block1 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block1) 59 | block1 = Dropout(dropout_rate)(block1) 60 | 61 | block2 = Conv2D(50, (1, 10), 62 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1) 63 | block2 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block2) 64 | block2 = Activation('elu')(block2) 65 | block2 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block2) 66 | block2 = Dropout(dropout_rate)(block2) 67 | 68 | block3 = Conv2D(100, (1, 10), 69 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block2) 70 | block3 =
BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block3) 71 | block3 = Activation('elu')(block3) 72 | block3 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block3) 73 | block3 = Dropout(dropout_rate)(block3) 74 | 75 | block4 = Conv2D(200, (1, 10), 76 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block3) 77 | block4 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block4) 78 | block4 = Activation('elu')(block4) 79 | block4 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block4) 80 | block4 = Dropout(dropout_rate)(block4) 81 | 82 | flatten = Flatten()(block4) 83 | 84 | dense = Dense(n_classes, kernel_constraint=max_norm(0.5))(flatten) 85 | softmax = Activation('softmax')(dense) 86 | 87 | return Model(inputs=input_main, outputs=softmax) 88 | 89 | 90 | # %% DEEP CONV NET 500 Hz 91 | def DeepConvNet_500Hz(n_classes=4, 92 | n_channels=64, 93 | n_samples=256, 94 | dropout_rate=0.5): 95 | """ 96 | # TODO: description for this model 97 | """ 98 | # input 99 | input_main = Input((1, n_channels, n_samples)) 100 | 101 | # block1 102 | block1 = Conv2D(25, (1, 20), 103 | # bias_initializer='truncated_normal', 104 | # kernel_initializer='he_normal', 105 | # kernel_regularizer=l2(0.0001), 106 | input_shape=(1, n_channels, n_samples), 107 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 108 | )(input_main) 109 | block1 = Conv2D(25, (n_channels, 1), 110 | # bias_initializer='truncated_normal', 111 | # kernel_initializer='he_normal', 112 | # kernel_regularizer=l2(0.0001), 113 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 114 | )(block1) 115 | block1 = BatchNormalization(axis=1)(block1) 116 | block1 = Activation('elu')(block1) 117 | block1 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block1) 118 | block1 = Dropout(dropout_rate)(block1) 119 | 120 | # block2 121 | block2 = Conv2D(50, (1, 20), 122 | # bias_initializer='truncated_normal', 123 | # kernel_initializer='he_normal', 124 | # kernel_regularizer=l2(0.0001), 125 | kernel_constraint=max_norm(2., 
axis=(0, 1, 2)) 126 | )(block1) 127 | block2 = BatchNormalization(axis=1)(block2) 128 | block2 = Activation('elu')(block2) 129 | block2 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block2) 130 | block2 = Dropout(dropout_rate)(block2) 131 | 132 | # block3 133 | block3 = Conv2D(100, (1, 20), 134 | # bias_initializer='truncated_normal', 135 | # kernel_initializer='he_normal', 136 | # kernel_regularizer=l2(0.0001), 137 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 138 | )(block2) 139 | block3 = BatchNormalization(axis=1)(block3) 140 | block3 = Activation('elu')(block3) 141 | block3 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block3) 142 | block3 = Dropout(dropout_rate)(block3) 143 | 144 | # block4 145 | block4 = Conv2D(200, (1, 20), 146 | # bias_initializer='truncated_normal', 147 | # kernel_initializer='he_normal', 148 | # kernel_regularizer=l2(0.0001), 149 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 150 | )(block3) 151 | block4 = BatchNormalization(axis=1)(block4) 152 | block4 = Activation('elu')(block4) 153 | block4 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block4) 154 | block4 = Dropout(dropout_rate)(block4) 155 | 156 | # flatten 157 | flatten = Flatten()(block4) 158 | 159 | # another dense one 160 | # dense = Dense(128, bias_initializer='truncated_normal', 161 | # kernel_initializer='he_normal', 162 | # kernel_regularizer=l2(0.001), 163 | # kernel_constraint=max_norm(0.5))(flatten) 164 | # dense = Activation('elu')(dense) 165 | # dense = Dropout(dropout_rate)(dense) 166 | 167 | # dense 168 | dense = Dense(n_classes, 169 | # bias_initializer='truncated_normal', 170 | # kernel_initializer='truncated_normal', 171 | kernel_constraint=max_norm(0.5) 172 | )(flatten) 173 | softmax = Activation('softmax')(dense) 174 | 175 | # returning the model 176 | return Model(inputs=input_main, outputs=softmax) 177 | 178 | 179 | # %% DEEP CONV NET DAVIDE 180 | def DeepConvNet_Davide(n_classes=4, 181 | n_channels=64, 182 | n_samples=256, 183 | 
dropout_rate=0.5): 184 | """ 185 | TODO: a description for this model 186 | :param n_classes: 187 | :param n_channels: 188 | :param n_samples: 189 | :param dropout_rate: 190 | :return: 191 | """ 192 | # start the model 193 | input_main = Input((1, n_channels, n_samples)) 194 | block1 = Conv2D(25, (1, 10), 195 | # bias_initializer='truncated_normal', 196 | # kernel_initializer='he_normal', 197 | # kernel_regularizer=l2(0.0001), 198 | input_shape=(1, n_channels, n_samples), 199 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 200 | )(input_main) 201 | block1 = Conv2D(25, (n_channels, 1), 202 | # bias_initializer='truncated_normal', 203 | # kernel_initializer='he_normal', 204 | # kernel_regularizer=l2(0.0001), 205 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 206 | )(block1) 207 | block1 = BatchNormalization(axis=1)(block1) 208 | block1 = Activation('elu')(block1) 209 | 210 | block1 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block1) 211 | block1 = Dropout(dropout_rate)(block1) 212 | 213 | block2 = Conv2D(50, (1, 10), 214 | # bias_initializer='truncated_normal', 215 | # kernel_initializer='he_normal', 216 | # kernel_regularizer=l2(0.0001), 217 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 218 | )(block1) 219 | block2 = BatchNormalization(axis=1)(block2) 220 | block2 = Activation('elu')(block2) 221 | block2 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block2) 222 | block2 = Dropout(dropout_rate)(block2) 223 | 224 | block3 = Conv2D(100, (1, 10), 225 | # bias_initializer='truncated_normal', 226 | # kernel_initializer='he_normal', 227 | # kernel_regularizer=l2(0.0001), 228 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 229 | )(block2) 230 | block3 = BatchNormalization(axis=1)(block3) 231 | block3 = Activation('elu')(block3) 232 | 233 | block3 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block3) 234 | block3 = Dropout(dropout_rate)(block3) 235 | 236 | block4 = Conv2D(200, (1, 10), 237 | # bias_initializer='truncated_normal', 238 | # 
kernel_initializer='he_normal', 239 | # kernel_regularizer=l2(0.0001), 240 | kernel_constraint=max_norm(2., axis=(0, 1, 2)) 241 | )(block3) 242 | block4 = BatchNormalization(axis=1)(block4) 243 | block4 = Activation('elu')(block4) 244 | 245 | block4 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block4) 246 | block4 = Dropout(dropout_rate)(block4) 247 | 248 | flatten = Flatten()(block4) 249 | # dense = Dense(128, bias_initializer='truncated_normal', 250 | # kernel_initializer='he_normal', 251 | # kernel_regularizer=l2(0.001), 252 | # kernel_constraint=max_norm(0.5))(flatten) 253 | # dense = Activation('elu')(dense) 254 | # dense = Dropout(dropout_rate)(dense) 255 | dense = Dense(n_classes, 256 | # bias_initializer='truncated_normal', 257 | # kernel_initializer='truncated_normal', 258 | kernel_constraint=max_norm(0.5) 259 | )(flatten) 260 | softmax = Activation('softmax')(dense) 261 | 262 | return Model(inputs=input_main, outputs=softmax) 263 | 264 | 265 | # %% SHALLOW CONV NET 266 | def square(x): 267 | return K.square(x) 268 | 269 | 270 | def log(x): 271 | return K.log(K.clip(x, min_value=1e-7, max_value=10000)) 272 | 273 | 274 | def ShallowConvNet(n_classes, 275 | n_channels=64, 276 | n_samples=128, 277 | dropout_rate=0.5): 278 | """ Keras implementation of the Shallow Convolutional Network as described 279 | in Schirrmeister et. al. (2017), Human Brain Mapping. 280 | 281 | Assumes the input is a 2-second EEG signal sampled at 128Hz. Note that in 282 | the original paper, they do temporal convolutions of length 25 for EEG 283 | data sampled at 250Hz. We instead use length 13 since the sampling rate is 284 | roughly half of the 250Hz which the paper used. The pool_size and stride 285 | in later layers is also approximately half of what is used in the paper. 286 | 287 | Note that we use the max_norm constraint on all convolutional layers, as 288 | well as the classification layer. We also change the defaults for the 289 | BatchNormalization layer. 
We used this based on a personal communication 290 | with the original authors. 291 | 292 | ours original paper 293 | pool_size 1, 35 1, 75 294 | strides 1, 7 1, 15 295 | conv filters 1, 13 1, 25 296 | 297 | Note that this implementation has not been verified by the original 298 | authors. We do note that this implementation reproduces the results in the 299 | original paper with minor deviations. 300 | """ 301 | 302 | # start the model 303 | input_main = Input((1, n_channels, n_samples)) 304 | block1 = Conv2D(40, (1, 13), 305 | input_shape=(1, n_channels, n_samples), 306 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(input_main) 307 | block1 = Conv2D(40, (n_channels, 1), use_bias=False, 308 | kernel_constraint=max_norm(2., axis=(0, 1, 2)))(block1) 309 | block1 = BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)(block1) 310 | block1 = Activation(square)(block1) 311 | block1 = AveragePooling2D(pool_size=(1, 35), strides=(1, 7))(block1) 312 | block1 = Activation(log)(block1) 313 | block1 = Dropout(dropout_rate)(block1) 314 | flatten = Flatten()(block1) 315 | dense = Dense(n_classes, kernel_constraint=max_norm(0.5))(flatten) 316 | softmax = Activation('softmax')(dense) 317 | 318 | return Model(inputs=input_main, outputs=softmax) 319 | 320 | 321 | # %% EEG NET 322 | def EEGNet(n_classes, 323 | n_channels=64, 324 | n_samples=128, 325 | dropout_rate=0.25, 326 | kernel_length=64, 327 | F1=4, 328 | D=2, 329 | F2=8, 330 | norm_rate=0.25, 331 | dropout_type='Dropout'): 332 | """ Keras Implementation of EEGNet 333 | http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta 334 | Note that this implements the newest version of EEGNet and NOT the earlier 335 | version (version v1 and v2 on arxiv). We strongly recommend using this 336 | architecture as it performs much better and has nicer properties than 337 | our earlier version. For example: 338 | 339 | 1. Depthwise Convolutions to learn spatial filters within a 340 | temporal convolution. 
The use of the depth_multiplier option maps 341 | exactly to the number of spatial filters learned within a temporal 342 | filter. This matches the setup of algorithms like FBCSP which learn 343 | spatial filters within each filter in a filter-bank. This also limits 344 | the number of free parameters to fit when compared to a fully-connected 345 | convolution. 346 | 347 | 2. Separable Convolutions to learn how to optimally combine spatial 348 | filters across temporal bands. Separable Convolutions are Depthwise 349 | Convolutions followed by (1x1) Pointwise Convolutions. 350 | 351 | 352 | While the original paper used Dropout, we found that SpatialDropout2D 353 | sometimes produced slightly better results for classification of ERP 354 | signals. However, SpatialDropout2D significantly reduced performance 355 | on the Oscillatory dataset (SMR, BCI-IV Dataset 2A). We recommend using 356 | the default Dropout in most cases. 357 | 358 | Assumes the input signal is sampled at 128Hz. If you want to use this model 359 | for any other sampling rate you will need to modify the lengths of temporal 360 | kernels and average pooling size in blocks 1 and 2 as needed (double the 361 | kernel lengths for double the sampling rate, etc). Note that we haven't 362 | tested the model performance with this rule so this may not work well. 363 | 364 | The model with default parameters gives the EEGNet-4,2 model as discussed 365 | in the paper. This model should do pretty well in general, although as the 366 | paper discussed the EEGNet-8,2 (with 8 temporal kernels and 2 spatial 367 | filters per temporal kernel) can do slightly better on the SMR dataset. 368 | Other variations that we found to work well are EEGNet-4,1 and EEGNet-8,1. 369 | We set F2 = F1 * D (number of input filters = number of output filters) for 370 | the SeparableConv2D layer. 
We haven't extensively tested other values of 371 | this parameter (say, F2 < F1 * D for compressed learning, and F2 > F1 * D 372 | for overcomplete). We believe the main parameters to focus on are F1 and D. 373 | Inputs: 374 | 375 | n_classes : int, number of classes to classify 376 | n_channels : number of channels 377 | n_samples : number of time points in the EEG data 378 | dropout_rate : dropout fraction 379 | kernel_length : length of temporal convolution in first layer. We found 380 | that setting this to be half the sampling rate worked 381 | well in practice. For the SMR dataset in particular 382 | since the data was high-passed at 4Hz we used a kernel 383 | length of 32. 384 | F1, F2 : number of temporal filters (F1) and number of pointwise 385 | filters (F2) to learn. Default: F1 = 4, F2 = F1 * D. 386 | D : number of spatial filters to learn within each temporal 387 | convolution. Default: D = 2 388 | dropout_type : Either SpatialDropout2D or Dropout, passed as a string. 389 | """ 390 | 391 | if dropout_type == 'SpatialDropout2D': 392 | dropout_type = SpatialDropout2D 393 | elif dropout_type == 'Dropout': 394 | dropout_type = Dropout 395 | else: 396 | raise ValueError('dropout_type must be one of SpatialDropout2D ' 397 | 'or Dropout, passed as a string.') 398 | 399 | input1 = Input(shape=(1, n_channels, n_samples)) 400 | 401 | ################################################################## 402 | block1 = Conv2D(F1, (1, kernel_length), padding='same', 403 | input_shape=(1, n_channels, n_samples), 404 | use_bias=False)(input1) 405 | block1 = BatchNormalization(axis=1)(block1) 406 | block1 = DepthwiseConv2D((n_channels, 1), use_bias=False, 407 | depth_multiplier=D, 408 | depthwise_constraint=max_norm(1.))(block1) 409 | block1 = BatchNormalization(axis=1)(block1) 410 | block1 = Activation('elu')(block1) 411 | block1 = AveragePooling2D((1, 4))(block1) 412 | block1 = dropout_type(dropout_rate)(block1) 413 | 414 | block2 = SeparableConv2D(F2, (1, 16), 415 | 
use_bias=False, padding='same')(block1) 416 | block2 = BatchNormalization(axis=1)(block2) 417 | block2 = Activation('elu')(block2) 418 | block2 = AveragePooling2D((1, 8))(block2) 419 | block2 = dropout_type(dropout_rate)(block2) 420 | 421 | flatten = Flatten(name='flatten')(block2) 422 | 423 | dense = Dense(n_classes, name='dense', 424 | kernel_constraint=max_norm(norm_rate))(flatten) 425 | softmax = Activation('softmax', name='softmax')(dense) 426 | 427 | return Model(inputs=input1, outputs=softmax) 428 | 429 | 430 | # %% EEG NET SSVEP 431 | def EEGNet_SSVEP(n_classes=12, 432 | n_channels=8, 433 | n_samples=256, 434 | dropout_rate=0.5, 435 | kernel_length=256, 436 | F1=96, 437 | D=1, 438 | F2=96, 439 | dropout_type='Dropout'): 440 | """ SSVEP Variant of EEGNet, as used in [1]. 441 | Inputs: 442 | 443 | n_classes : int, number of classes to classify 444 | n_channels : number of channels 445 | n_samples : number of time points in the EEG data 446 | dropout_rate : dropout fraction 447 | kernel_length : length of temporal convolution in first layer 448 | F1, F2 : number of temporal filters (F1) and number of pointwise 449 | filters (F2) to learn. 450 | D : number of spatial filters to learn within each temporal 451 | convolution. 452 | dropout_type : Either SpatialDropout2D or Dropout, passed as a string. 453 | 454 | 455 | [1]. Waytowich, N. et. al. (2018). Compact Convolutional Neural Networks 456 | for Classification of Asynchronous Steady-State Visual Evoked Potentials. 457 | Journal of Neural Engineering vol. 15(6). 
458 | http://iopscience.iop.org/article/10.1088/1741-2552/aae5d8 459 | """ 460 | 461 | if dropout_type == 'SpatialDropout2D': 462 | dropout_type = SpatialDropout2D 463 | elif dropout_type == 'Dropout': 464 | dropout_type = Dropout 465 | else: 466 | raise ValueError('dropout_type must be one of SpatialDropout2D ' 467 | 'or Dropout, passed as a string.') 468 | 469 | input1 = Input(shape=(1, n_channels, n_samples)) 470 | 471 | ################################################################## 472 | block1 = Conv2D(F1, (1, kernel_length), padding='same', 473 | input_shape=(1, n_channels, n_samples), 474 | use_bias=False)(input1) 475 | block1 = BatchNormalization(axis=1)(block1) 476 | block1 = DepthwiseConv2D((n_channels, 1), use_bias=False, 477 | depth_multiplier=D, 478 | depthwise_constraint=max_norm(1.))(block1) 479 | block1 = BatchNormalization(axis=1)(block1) 480 | block1 = Activation('elu')(block1) 481 | block1 = AveragePooling2D((1, 4))(block1) 482 | block1 = dropout_type(dropout_rate)(block1) 483 | 484 | block2 = SeparableConv2D(F2, (1, 16), 485 | use_bias=False, padding='same')(block1) 486 | block2 = BatchNormalization(axis=1)(block2) 487 | block2 = Activation('elu')(block2) 488 | block2 = AveragePooling2D((1, 8))(block2) 489 | block2 = dropout_type(dropout_rate)(block2) 490 | 491 | flatten = Flatten(name='flatten')(block2) 492 | 493 | dense = Dense(n_classes, name='dense')(flatten) 494 | softmax = Activation('softmax', name='softmax')(dense) 495 | 496 | return Model(inputs=input1, outputs=softmax) 497 | 498 | 499 | # %% EEG NET OLD 500 | def EEGNet_old(n_classes, 501 | n_channels=64, 502 | n_samples=128, 503 | regRate=0.0001, 504 | dropout_rate=0.25, 505 | kernels=None, 506 | strides=(2, 4)): 507 | """ Keras Implementation of EEGNet_v1 (https://arxiv.org/abs/1611.08024v2) 508 | This model is the original EEGNet model proposed on arxiv 509 | https://arxiv.org/abs/1611.08024v2 510 | 511 | with a few modifications: we use striding instead of max-pooling as this 
512 | helped slightly in classification performance while also providing a 513 | computational speed-up. 514 | 515 | Note that we no longer recommend the use of this architecture, as the new 516 | version of EEGNet performs much better overall and has nicer properties. 517 | 518 | Inputs: 519 | 520 | n_classes : total number of final categories 521 | n_channels : number of EEG channels 522 | n_samples : number of EEG time points 523 | regRate : regularization rate for L1 and L2 regularizations 524 | dropout_rate : dropout fraction 525 | kernels : the 2nd and 3rd layer kernel dimensions (default is 526 | the [2, 32] x [8, 4] configuration) 527 | strides : the stride size (note that this replaces the max-pool 528 | used in the original paper) 529 | 530 | """ 531 | # fixing PEP8 mutable input issue 532 | if kernels is None: 533 | kernels = [(2, 32), (8, 4)] 534 | 535 | # start the model 536 | input_main = Input((1, n_channels, n_samples)) 537 | layer1 = Conv2D(16, (n_channels, 1), 538 | input_shape=(1, n_channels, n_samples), 539 | kernel_regularizer=l1_l2(l1=regRate, l2=regRate))( 540 | input_main) 541 | layer1 = BatchNormalization(axis=1)(layer1) 542 | layer1 = Activation('elu')(layer1) 543 | layer1 = Dropout(dropout_rate)(layer1) 544 | 545 | permute_dims = 2, 1, 3 546 | permute1 = Permute(permute_dims)(layer1) 547 | 548 | layer2 = Conv2D(4, kernels[0], padding='same', 549 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate), 550 | strides=strides)(permute1) 551 | layer2 = BatchNormalization(axis=1)(layer2) 552 | layer2 = Activation('elu')(layer2) 553 | layer2 = Dropout(dropout_rate)(layer2) 554 | 555 | layer3 = Conv2D(4, kernels[1], padding='same', 556 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate), 557 | strides=strides)(layer2) 558 | layer3 = BatchNormalization(axis=1)(layer3) 559 | layer3 = Activation('elu')(layer3) 560 | layer3 = Dropout(dropout_rate)(layer3) 561 | 562 | flatten = Flatten(name='flatten')(layer3) 563 | 564 | dense = Dense(n_classes, 
name='dense')(flatten) 565 | softmax = Activation('softmax', name='softmax')(dense) 566 | 567 | return Model(inputs=input_main, outputs=softmax) 568 | -------------------------------------------------------------------------------- /hgdecode/signalproc.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import scipy as sp 5 | 6 | from braindecode.datautil.signal_target import SignalAndTarget 7 | from braindecode.datautil.signalproc import bandpass_cnt 8 | from braindecode.mne_ext.signalproc import mne_apply 9 | 10 | 11 | def bandpass_mne(cnt, low_cut_hz, high_cut_hz, filt_order=3, axis=0): 12 | return mne_apply(lambda data: bandpass_cnt( 13 | data.T, low_cut_hz, high_cut_hz, fs=cnt.info['sfreq'], 14 | filt_order=filt_order, axis=axis).T, 15 | cnt) 16 | 17 | 18 | def select_trials(dataset, inds): 19 | if hasattr(dataset.X, 'ndim'): 20 | # numpy array 21 | new_X = np.array(dataset.X)[inds] 22 | else: 23 | # list 24 | new_X = [dataset.X[i] for i in inds] 25 | new_y = np.asarray(dataset.y)[inds] 26 | return SignalAndTarget(new_X, new_y) 27 | 28 | 29 | def select_classes_cnt(cnt, class_numbers): 30 | cnt = deepcopy(cnt) 31 | events = cnt.info['events'] 32 | new_events = [ev for ev in events if 33 | (ev[2] - ev[1]) in class_numbers] 34 | cnt.info['events'] = np.array(new_events) 35 | return cnt 36 | 37 | 38 | def select_classes(dataset, class_numbers): 39 | wanted_inds = [i_trial for i_trial, y in enumerate(dataset.y) 40 | if y in class_numbers] 41 | return select_trials(dataset, wanted_inds) 42 | 43 | 44 | def select_trials_cnt(cnt, inds): 45 | cnt = deepcopy(cnt) 46 | assert np.all( 47 | [i in np.arange(len(cnt.info['events'])) for i in inds]) 48 | events = cnt.info['events'] 49 | new_events = [ev for i_trial, ev in enumerate(events) if 50 | i_trial in inds] 51 | cnt.info['events'] = np.array(new_events) 52 | return cnt 53 | 54 | 55 | def concatenate_channels(datasets): 56 | 
all_X = [dataset.X for dataset in datasets] 57 | new_X = np.concatenate(all_X, axis=1) 58 | new_y = datasets[0].y 59 | for dataset in datasets: 60 | assert np.array_equal(dataset.y, new_y) 61 | return SignalAndTarget(new_X, new_y) 62 | 63 | 64 | def extract_all_start_codes(name_to_start_codes): 65 | all_start_codes = [] 66 | for val in name_to_start_codes.values(): 67 | if hasattr(val, '__len__'): 68 | all_start_codes.extend(val) 69 | else: 70 | all_start_codes.append(val) 71 | return all_start_codes 72 | 73 | 74 | def calculate_csp(epo, classes=None, average_trial_covariance=False): 75 | """Calculate the Common Spatial Pattern (CSP) for two classes. 76 | Now with pattern computation as in matlab bbci toolbox 77 | https://github.com/bbci/bbci_public/blob/c7201e4e42f873cced2e068c6cbb3780a8f8e9ec/processing/proc_csp.m#L112 78 | 79 | This method calculates the CSP and the corresponding filters. Use 80 | the columns of the patterns and filters. 81 | Examples 82 | -------- 83 | Calculate the CSP for the first two classes:: 84 | >> w, a, d = calculate_csp(epo) 85 | >> # Apply the first two and the last two columns of the sorted 86 | >> # filter to the data 87 | >> filtered = apply_spatial_filter(epo, w[:, [0, 1, -2, -1]]) 88 | >> # You'll probably want to get the log-variance along the time 89 | >> # axis, this should result in four numbers (one for each 90 | >> # channel) 91 | >> filtered = np.log(np.var(filtered, 0)) 92 | Select two classes manually:: 93 | >> w, a, d = calculate_csp(epo, [2, 5]) 94 | Parameters 95 | ---------- 96 | epo : epoched Data object 97 | this method relies on the ``epo`` to have three dimensions in 98 | the following order: class, time, channel 99 | classes : list of two ints, optional 100 | If ``None`` the first two different class indices found in 101 | ``epo.axes[0]`` are chosen automatically otherwise the class 102 | indices can be manually chosen by setting ``classes`` 103 | average_trial_covariance : bool 104 | Returns 105 | ------- 106 | 
v : 2d array 107 | the sorted spatial filters 108 | a : 2d array 109 | the sorted spatial patterns. Column i of a represents the 110 | pattern of the filter in column i of v. 111 | d : 1d array 112 | the variances of the components 113 | Raises 114 | ------ 115 | AssertionError : 116 | If: 117 | * ``classes`` is not ``None`` and has less than two elements 118 | * ``classes`` is not ``None`` and the first two elements are 119 | not found in the ``epo`` 120 | * ``classes`` is ``None`` but there are less than two 121 | different classes in the ``epo`` 122 | See Also 123 | -------- 124 | :func:`apply_spatial_filter`, :func:`apply_csp`, :func:`calculate_spoc` 125 | References 126 | ---------- 127 | http://en.wikipedia.org/wiki/Common_spatial_pattern 128 | """ 129 | if classes is None: 130 | # automagically find the first two different classidx 131 | # we don't use uniq, since it sorts the classidx first 132 | # first check if we have a least two diffeent idxs: 133 | unique_classes = np.unique(epo.y) 134 | assert len(unique_classes) == 2 135 | cidx1 = unique_classes[0] 136 | cidx2 = unique_classes[1] 137 | else: 138 | assert (len(classes) == 2 and 139 | classes[0] in epo.y and 140 | classes[1] in epo.y) 141 | cidx1 = classes[0] 142 | cidx2 = classes[1] 143 | epo1 = select_classes(epo, [cidx1]) 144 | epo2 = select_classes(epo, [cidx2]) 145 | if average_trial_covariance: 146 | # computing c1 as mean covariance of trial covariances: 147 | c1 = np.mean([np.cov(x) for x in epo1.X], axis=0) 148 | c2 = np.mean([np.cov(x) for x in epo2.X], axis=0) 149 | else: 150 | # we need a matrix of the form (channels, observations) so we stack 151 | # trials and time per channel together 152 | x1 = np.concatenate(epo1.X, axis=1) 153 | x2 = np.concatenate(epo2.X, axis=1) 154 | # compute covariance matrices of the two classes 155 | c1 = np.cov(x1) 156 | c2 = np.cov(x2) 157 | # solution of csp objective via generalized eigenvalue problem 158 | # in matlab the signature is v, d = eig(a, b) 159 
| 160 | # d1, v1 = sp.linalg.eigh(c2, c1 + c2) # old instruction 161 | d, v = sp.linalg.eig(c2, c1 + c2) 162 | d = d.real 163 | # make sure the eigenvalues and -vectors are correctly sorted 164 | indx = np.argsort(d) 165 | # reverse 166 | indx = indx[::-1] 167 | d = d.take(indx) 168 | v = v.take(indx, axis=1) 169 | 170 | # Now compute patterns 171 | # old pattern computation 172 | # a = sp.linalg.inv(v).transpose() 173 | c_avg = (c1 + c2) / 2.0 174 | 175 | # compare 176 | # https://github.com/bbci/bbci_public/blob/ 177 | # c7201e4e42f873cced2e068c6cbb3780a8f8e9ec/processing/proc_csp.m#L112 178 | # with W := v 179 | v_with_cov = np.dot(c_avg, v) 180 | source_cov = np.dot(np.dot(v.T, c_avg), v) 181 | # matlab-python comparison 182 | """ 183 | v_with_cov = np.array([[1,2,-2], 184 | [3,-2,4], 185 | [5,1,0.3]]) 186 | 187 | source_cov = np.array([[1,2,0.5], 188 | [2,0.6,4], 189 | [0.5,4,2]]) 190 | 191 | sp.linalg.solve(source_cov.T, v_with_cov.T).T 192 | # for matlab 193 | v_with_cov = [[1,2,-2], 194 | [3,-2,4], 195 | [5,1,0.3]] 196 | 197 | source_cov = [[1,2,0.5], 198 | [2,0.6,4], 199 | [0.5,4,2]] 200 | v_with_cov / source_cov""" 201 | 202 | a = sp.linalg.solve(source_cov.T, v_with_cov.T).T 203 | return v, a, d 204 | 205 | 206 | def apply_csp_fast(epo, filt, columns=[0, -1]): 207 | """Apply the CSP filter. 208 | 209 | Apply the spacial CSP filter to the epoched data. 210 | 211 | Parameters 212 | ---------- 213 | epo : epoched ``Data`` object 214 | this method relies on the ``epo`` to have three dimensions in 215 | the following order: class, time, channel 216 | filt : 2d array 217 | the CSP filter (i.e. the ``v`` return value from 218 | :func:`calculate_csp`) 219 | columns : array of ints, optional 220 | the columns of the filter to use. The default is the first and 221 | the last one. 222 | 223 | Returns 224 | ------- 225 | epo : epoched ``Data`` object 226 | The channels from the original have been replaced with the new 227 | virtual CSP channels. 
228 | 229 | Examples 230 | -------- 231 | 232 | >> w, a, d = calculate_csp(epo) 233 | >> epo = apply_csp_fast(epo, w) 234 | 235 | See Also 236 | -------- 237 | :func:`calculate_csp` 238 | :func:`apply_csp` 239 | 240 | """ 241 | # getting only selected columns 242 | f = filt[:, columns] 243 | 244 | # pre-allocating filtered vector 245 | filtered = [] 246 | 247 | # TODO: pre-allocate in the right way, no append please 248 | 249 | # filtering on each trial 250 | for trial_i in range(len(epo.X)): 251 | # time x filters 252 | this_filtered = np.dot(epo.X[trial_i].T, f) 253 | # to filters x time 254 | filtered.append(this_filtered.T) 255 | return SignalAndTarget(filtered, epo.y) 256 | 257 | 258 | def apply_csp_var_log(epo, filters, columns): 259 | csp_filtered = apply_csp_fast(epo, filters, columns) 260 | # 1 is t 261 | csp_filtered.X = np.array([ 262 | np.log(np.var(trial, axis=1)) for trial in csp_filtered.X 263 | ]) 264 | return csp_filtered 265 | -------------------------------------------------------------------------------- /hgdecode/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | import datetime 4 | import numpy as np 5 | import logging as log 6 | from os import getcwd 7 | from os import system 8 | from os import listdir 9 | from os import makedirs 10 | from os import remove 11 | from sys import platform 12 | from pickle import dump 13 | from os.path import join 14 | from os.path import exists 15 | from os.path import dirname 16 | from sklearn.metrics import confusion_matrix 17 | 18 | now_dir = '' 19 | 20 | 21 | def dash_printer(input_string='', manual_input=32): 22 | sys.stdout.write('-' * len(input_string) + '-' * manual_input + '\n') 23 | 24 | 25 | def print_manager(input_string='', 26 | print_style='normal', 27 | top_return=None, 28 | bottom_return=None): 29 | # top return 30 | if top_return is not None: 31 | sys.stdout.write('\n' * top_return) 32 | 33 | # input_string 34 | if 
print_style == 'normal': 35 | log.info(input_string) 36 | elif print_style == 'top-dashed': 37 | dash_printer(input_string) 38 | log.info(input_string) 39 | print_manager.last_print = input_string 40 | elif print_style == 'bottom-dashed': 41 | log.info(input_string) 42 | dash_printer(input_string) 43 | print_manager.last_print = input_string 44 | elif print_style == 'double-dashed': 45 | dash_printer(input_string) 46 | log.info(input_string) 47 | dash_printer(input_string) 48 | print_manager.last_print = input_string 49 | elif print_style == 'last': 50 | log.info(input_string) 51 | dash_printer(print_manager.last_print) 52 | print_manager.last_print = input_string 53 | 54 | # bottom return 55 | if bottom_return is not None: 56 | sys.stdout.write('\n' * bottom_return) 57 | 58 | 59 | def datetime_dir_format(): 60 | now = datetime.datetime.now() 61 | second = str(now.second) 62 | minute = str(now.minute) 63 | hour = str(now.hour) 64 | day = str(now.day) 65 | month = str(now.month) 66 | year = str(now.year) 67 | if len(second) == 1: 68 | second = '0' + second 69 | if len(minute) == 1: 70 | minute = '0' + minute 71 | if len(hour) == 1: 72 | hour = '0' + hour 73 | if len(day) == 1: 74 | day = '0' + day 75 | if len(month) == 1: 76 | month = '0' + month 77 | return '_'.join(['-'.join([year, month, day]), 78 | '-'.join([hour, minute, second])]) 79 | 80 | 81 | def touch_dir(directory): 82 | if exists(directory): 83 | return True 84 | else: 85 | makedirs(directory) 86 | return False 87 | 88 | 89 | def touch_file(file_path): 90 | if exists(file_path): 91 | return True 92 | else: 93 | return False 94 | 95 | 96 | def listdir2(path): 97 | l = listdir(path) 98 | idx = 0 99 | while idx < len(l): 100 | if l[idx][0] is '.': 101 | l.remove(l[idx]) 102 | else: 103 | idx += 1 104 | return l 105 | 106 | 107 | def create_log(results_dir, 108 | learning_type='ml', 109 | algorithm_or_model_name='FBCSP_rLDA', 110 | subject_id=1, 111 | output_on_file=False, 112 | 
use_last_result_directory=False): 113 | # getting now_dir from global 114 | global now_dir 115 | 116 | # setting temporary results directory 117 | results_dir = join(results_dir, 118 | learning_type, 119 | algorithm_or_model_name) 120 | 121 | # setting now_dir if necessary 122 | if len(now_dir) is 0: 123 | if use_last_result_directory is True: 124 | dirs_in_folder = listdir(results_dir) 125 | dirs_in_folder.sort() 126 | now_dir = dirs_in_folder[-1] 127 | else: 128 | now_dir = datetime_dir_format() 129 | 130 | # setting log_file_dir 131 | log_file_dir = join(results_dir, now_dir) 132 | 133 | # setting subject_id_str 134 | if type(subject_id) is str: 135 | subject_str = subject_id 136 | else: 137 | subject_str = str(subject_id) 138 | if len(subject_str) == 1: 139 | subject_str = '0' + subject_str 140 | subject_str = 'subj' + subject_str 141 | 142 | # setting subject_results_dir 143 | subject_results_dir = join(log_file_dir, subject_str) 144 | 145 | # touching directories 146 | touch_dir(log_file_dir) 147 | touch_dir(subject_results_dir) 148 | 149 | if output_on_file is True: 150 | # setting log_file_name 151 | log_file_name = 'log.bin' 152 | 153 | # setting log_file_path 154 | log_file_path = join(log_file_dir, log_file_name) 155 | 156 | # creating the log file 157 | sys.stdout = open(log_file_path, 'w') 158 | 159 | # opening it using system commands 160 | if platform == 'linux': 161 | system('xdg-open ' + log_file_path.replace(' ', '\ ')) 162 | elif platform == 'darwin': # macOSX 163 | system('open ' + log_file_path.replace(' ', '\ ')) 164 | else: 165 | sys.stdout.write('platform {:s} still not supported'.format( 166 | platform)) 167 | 168 | # setting the logging object configuration 169 | log.basicConfig( 170 | format='%(asctime)s | %(levelname)s: %(message)s', 171 | filemode='w', 172 | stream=sys.stdout, 173 | level=log.DEBUG 174 | ) 175 | else: 176 | # setting the logging object configuration 177 | log.basicConfig( 178 | format='%(asctime)s | %(levelname)s: 
%(message)s', 179 | level=log.DEBUG, 180 | stream=sys.stdout 181 | ) 182 | 183 | # printing current cycle information 184 | print_manager( 185 | '{} with {} on subject {}'.format( 186 | learning_type.upper(), 187 | algorithm_or_model_name, 188 | subject_id 189 | ), 190 | 'double-dashed', 191 | bottom_return=1 192 | ) 193 | 194 | # returning subject_results_dir, in some case it can be helpful 195 | return subject_results_dir 196 | 197 | 198 | def ml_results_saver(exp, subj_results_dir): 199 | for fold_idx in range(exp.n_folds): 200 | # computing paths and directories 201 | fold_str = str(fold_idx + 1) 202 | if len(fold_str) is not 2: 203 | fold_str = '0' + fold_str 204 | fold_str = 'fold' + fold_str 205 | fold_dir = join(subj_results_dir, fold_str) 206 | touch_dir(fold_dir) 207 | file_path = join(fold_dir, 'fold_stats.pickle') 208 | 209 | # getting accuracies 210 | train_acc = exp.multi_class.train_accuracy[fold_idx] 211 | test_acc = exp.multi_class.test_accuracy[fold_idx] 212 | 213 | # getting y_true and y_pred for this fold 214 | train_true = exp.multi_class.train_labels[fold_idx] 215 | train_pred = exp.multi_class.train_predicted_labels[fold_idx] 216 | test_true = exp.multi_class.test_labels[fold_idx] 217 | test_pred = exp.multi_class.test_predicted_labels[fold_idx] 218 | 219 | # computing confusion matrices 220 | train_conf_mtx = confusion_matrix(train_true, train_pred) 221 | test_conf_mtx = confusion_matrix(test_true, test_pred) 222 | 223 | # creating results dictionary 224 | results = { 225 | 'train': { 226 | 'acc': train_acc, 227 | 'conf_mtx': train_conf_mtx.tolist() 228 | }, 229 | 'test': { 230 | 'acc': test_acc, 231 | 'conf_mtx': test_conf_mtx.tolist() 232 | } 233 | } 234 | 235 | # saving results 236 | with open(file_path, 'wb') as f: 237 | dump(results, f) 238 | 239 | 240 | def csv_manager(csv_path, line): 241 | if not exists(csv_path): 242 | with open(csv_path, 'w', newline='') as f: 243 | writer = csv.writer(f, delimiter=',') 244 | 
writer.writerows([line]) 245 | else: 246 | with open(csv_path, 'a', newline='') as f: 247 | writer = csv.writer(f) 248 | writer.writerows([line]) 249 | 250 | 251 | def get_metrics_from_conf_mtx(conf_mtx, label_names=None): 252 | # creating standard label_names if not specified 253 | if label_names is None: 254 | label_names = ['label ' + str(x) for x in range(len(conf_mtx))] 255 | 256 | # computing true/false positive/negative from confusion matrix 257 | TP = np.diag(conf_mtx) 258 | FP = conf_mtx.sum(axis=0) - TP 259 | FN = conf_mtx.sum(axis=1) - TP 260 | TN = conf_mtx.sum() - (FP + FN + TP) 261 | 262 | # parsing to float 263 | TP = TP.astype(float) 264 | TN = TN.astype(float) 265 | FP = FP.astype(float) 266 | FN = FN.astype(float) 267 | 268 | # computing true positive rate (sensitivity, hit rate, recall) 269 | TPR = TP / (TP + FN) 270 | 271 | # computing true negative rate (specificity) 272 | TNR = TN / (TN + FP) 273 | 274 | # computing positive predictive value (precision) 275 | PPV = TP / (TP + FP) 276 | 277 | # computing negative predictive value 278 | NPV = TN / (TN + FN) 279 | 280 | # computing false positive rate (fall out) 281 | FPR = FP / (FP + TN) 282 | 283 | # computing false negative rate 284 | FNR = FN / (TP + FN) 285 | 286 | # computing false discovery rate 287 | FDR = FP / (TP + FP) 288 | 289 | # computing f1-score 290 | F1 = 2 * TP / (2 * TP + FP + FN) 291 | 292 | # computing accuracy on single label 293 | ACC = (TP + TN) / (TP + FP + FN + TN) 294 | 295 | # computing overall accuracy 296 | acc = TP.sum() / conf_mtx.sum() 297 | 298 | # pre-allocating metrics_report 299 | metrics_report = {x: None for x in label_names} 300 | metrics_report['acc'] = acc 301 | 302 | # filling metrics_report 303 | for idx, label in enumerate(label_names): 304 | metrics_report[label] = { 305 | 'TP': TP[idx], 306 | 'TN': TN[idx], 307 | 'FP': FP[idx], 308 | 'FN': FN[idx], 309 | 'TPR': TPR[idx], 310 | 'TNR': TNR[idx], 311 | 'PPV': PPV[idx], 312 | 'NPV': NPV[idx], 313 | 
'FPR': FPR[idx], 314 | 'FNR': FNR[idx], 315 | 'FDR': FDR[idx], 316 | 'F1': F1[idx], 317 | 'ACC': ACC[idx] 318 | } 319 | 320 | # returning metrics_report 321 | return metrics_report 322 | 323 | 324 | def check_significant_digits(num): 325 | num = float(num) 326 | if num < 0: 327 | negative_flag = True 328 | num = -num 329 | else: 330 | negative_flag = False 331 | if num < 0.01: # from 0.009999 332 | num = np.round(num, 5) 333 | elif num < 0.1: # from 0.09999 334 | num = np.round(num, 4) 335 | else: 336 | num = np.round(num, 3) 337 | num = num * 100 338 | if num == 0: 339 | num = '0' 340 | elif num < 1: 341 | num = np.round(num, 3) 342 | num = str(num) 343 | num += '0' * (5 - len(num)) 344 | num = num[0:4] 345 | elif num < 10: 346 | num = np.round(num, 2) 347 | num = str(num) 348 | num += '0' * (4 - len(num)) 349 | num = num[0:4] 350 | elif num == 100: 351 | num = '100' 352 | else: 353 | num = np.round(num, 1) 354 | num = str(num) 355 | if negative_flag: 356 | num = '-' + num 357 | return num 358 | 359 | 360 | def get_subj_str(subj_id): 361 | return my_formatter(subj_id, 'subj') 362 | 363 | 364 | def get_fold_str(fold_id): 365 | return my_formatter(fold_id, 'fold') 366 | 367 | 368 | def my_formatter(num, name): 369 | num_str = str(num) 370 | if len(num_str) == 1: 371 | num_str = '0' + num_str 372 | return name + num_str 373 | 374 | 375 | def get_path(results_dir=None, 376 | learning_type='dl', 377 | algorithm_or_model_name=None, 378 | epoching=(-500, 4000), 379 | fold_type='single_subject', 380 | n_folds=2, 381 | deprecated=False, 382 | balanced_folds=True): 383 | # checking results_dir 384 | if results_dir is None: 385 | results_dir = join(dirname(dirname(dirname(getcwd()))), 'results') 386 | 387 | # checking algorithm_or_model_name 388 | if algorithm_or_model_name is None: 389 | if learning_type == 'ml': 390 | algorithm_or_model_name = 'FBCSP_rLDA' 391 | elif learning_type == 'dl': 392 | algorithm_or_model_name = 'DeepConvNet' 393 | else: 394 | raise ValueError( 
395 | 'Invalid learning_type inputed: {}'.format(learning_type) 396 | ) 397 | 398 | # checking epoching 399 | if epoching.__class__ is tuple or epoching.__class__ is list: 400 | epoching_str = str(epoching[0]) + '_' + str(epoching[1]) 401 | elif epoching.__class__ is str: 402 | epoching_str = epoching 403 | else: 404 | raise ValueError( 405 | 'Invalid epoching type: {}'.format(epoching.__class__) 406 | ) 407 | 408 | # checking fold_type 409 | folder = '' 410 | if fold_type == 'schirrmeister': 411 | folder += '1' 412 | elif fold_type == 'single_subject': 413 | if epoching_str == '-500_4000': 414 | folder += '2' 415 | elif epoching_str == '-1000_1000': 416 | folder += '3' 417 | elif epoching_str == '-1500_500': 418 | folder += '4' 419 | elif fold_type == 'cross_subject': 420 | if epoching_str == '-500_4000': 421 | folder += '5' 422 | elif epoching_str == '-1000_1000': 423 | folder += '6' 424 | elif fold_type == 'transfer_learning': 425 | if epoching_str == '-500_4000': 426 | folder += '7' 427 | else: 428 | folder += '8' 429 | elif fold_type == 'transfer_learning_frozen': 430 | folder += '9' 431 | else: 432 | raise ValueError( 433 | 'Invalid fold_type: {}'.format(fold_type) 434 | ) 435 | folder += '_' + fold_type + '_' + epoching_str 436 | 437 | # checking for deprecated / not stratified 438 | if deprecated is True: 439 | if balanced_folds is True: 440 | folder = join('0_deprecated', folder) 441 | else: 442 | folder = join('0_deprecated', '#_not_stratified', folder) 443 | 444 | # building folder path 445 | folder_path = join(results_dir, 446 | 'hgdecode', 447 | learning_type, 448 | algorithm_or_model_name, 449 | folder) 450 | 451 | if (fold_type == 'single_subject') and (epoching_str == '-500_4000'): 452 | folder_path = join(folder_path, my_formatter(n_folds, 'fold')) 453 | elif fold_type == 'transfer_learning': 454 | folder_path = join(folder_path, 'train_size_' + str(n_folds)) 455 | elif fold_type == 'transfer_learning_frozen': 456 | folder_path = join(folder_path, 
'frozen_' + str(n_folds), 457 | 'train_size_128') 458 | 459 | return join(folder_path, listdir2(folder_path)[0]) 460 | 461 | 462 | def clear_all_models(subj_results_dir): 463 | fold_folders = listdir2(subj_results_dir) 464 | for fold_folder in fold_folders: 465 | remove(join(subj_results_dir, fold_folder, 'net_best_val_loss.h5')) 466 | -------------------------------------------------------------------------------- /ml_main.py: -------------------------------------------------------------------------------- 1 | from os import getcwd 2 | from os.path import join 3 | from os.path import dirname 4 | from collections import OrderedDict 5 | from numpy.random import RandomState 6 | from hgdecode.utils import create_log 7 | from hgdecode.utils import my_formatter 8 | from hgdecode.utils import ml_results_saver 9 | from hgdecode.classes import CrossValidation 10 | from hgdecode.loaders import ml_loader 11 | from hgdecode.experiments import FBCSPrLDAExperiment 12 | 13 | """ 14 | SETTING PARAMETERS 15 | ------------------ 16 | In the following, you have to set / modify all the parameters to use for 17 | further computation. 
18 | 19 | Parameters 20 | ---------- 21 | algorithm_name : str 22 | Machine Learning algorithm name that is going to be used 23 | channel_names : list 24 | Channels to use for computation 25 | data_dir : str 26 | Path to the directory that contains dataset 27 | name_to_start_codes : OrderedDict 28 | All possible classes names and codes in an ordered dict format 29 | results_dir : str 30 | 31 | subject_ids : tuple 32 | All the subject ids in a tuple; add or remove subjects to run the 33 | algorithm for them or not 34 | """ 35 | # setting ml_algorithm 36 | algorithm_name = 'FBCSP_rLDA' 37 | 38 | # setting channel_names 39 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 40 | 'CP5', 'CP1', 'CP2', 'CP6', 41 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 42 | 'CP3', 'CPz', 'CP4', 43 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 44 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 45 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 46 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h', 47 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 48 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h'] 49 | 50 | # setting data_dir & results_dir 51 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma') 52 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode') 53 | 54 | # setting name_to_start_codes 55 | name_to_start_codes = OrderedDict([('Right Hand', [1]), 56 | ('Left Hand', [2]), 57 | ('Rest', [3]), 58 | ('Feet', [4])]) 59 | 60 | # setting random state 61 | random_state = RandomState(1234) 62 | 63 | # real useful hyperparameters 64 | standardize_mode = 0 65 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) 66 | ival = (-500, 4000) 67 | n_folds = 6 68 | 69 | # fold stuff 70 | ival_str = str(ival[0]) + '_' + str(ival[1]) 71 | fold_dir = join(data_dir, 72 | 'stratified_fold_' + ival_str, 73 | my_formatter(n_folds, 'fold')) 74 | 75 | """ 76 | MAIN CYCLE 77 | ---------- 78 | For each subject, a new log will be created and the specific dataset loaded; 79 | this dataset will be used to create an 
instance of the experiment; then the 80 | experiment will be run. You can of course change all the experiment inputs 81 | to obtain different results. 82 | """ 83 | for subject_id in subject_ids: 84 | # creating a log object 85 | subj_results_dir = create_log( 86 | results_dir=results_dir, 87 | learning_type='ml', 88 | algorithm_or_model_name=algorithm_name, 89 | subject_id=subject_id, 90 | output_on_file=False 91 | ) 92 | 93 | # loading dataset 94 | cnt, clean_trial_mask = ml_loader( 95 | data_dir=data_dir, 96 | name_to_start_codes=name_to_start_codes, 97 | channel_names=channel_names, 98 | subject_id=subject_id, 99 | resampling_freq=250, # Schirrmeister: 250 100 | clean_ival_ms=ival, # Schirrmeister: (0, 4000) 101 | train_test_split=True, # Schirrmeister: True 102 | clean_on_all_channels=False, # Schirrmeister: True 103 | standardize_mode=standardize_mode # Schirrmeister: 2 104 | ) 105 | 106 | # creating experiment instance 107 | exp = FBCSPrLDAExperiment( 108 | # signal-related inputs 109 | cnt=cnt, 110 | clean_trial_mask=clean_trial_mask, 111 | name_to_start_codes=name_to_start_codes, 112 | random_state=random_state, 113 | name_to_stop_codes=None, # Schirrmeister: None 114 | epoch_ival_ms=ival, # Schirrmeister: (-500, 4000) 115 | 116 | # bank filter-related inputs 117 | min_freq=[0, 10], # Schirrmeister: [0, 10] 118 | max_freq=[12, 122], # Schirrmeister: [12, 122] 119 | window=[6, 8], # Schirrmeister: [6, 8] 120 | overlap=[3, 4], # Schirrmeister: [3, 4] 121 | filt_order=3, # filt_order: 3 122 | 123 | # machine learning parameters 124 | n_folds=n_folds, # Schirrmeister: ? 
125 | fold_file=join(fold_dir, my_formatter(subject_id, 'subj') + '.npz'), 126 | n_top_bottom_csp_filters=5, # Schirrmeister: 5 127 | n_selected_filterbands=None, # Schirrmeister: None 128 | n_selected_features=20, # Schirrmeister: 20 129 | forward_steps=2, # Schirrmeister: 2 130 | backward_steps=1, # Schirrmeister: 1 131 | stop_when_no_improvement=False, # Schirrmeister: False 132 | shuffle=False, # Schirrmeister: False 133 | average_trial_covariance=True # Schirrmeister: True 134 | ) 135 | 136 | # running the experiment 137 | exp.run() 138 | 139 | # saving results for this subject 140 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir) 141 | 142 | # computing statistics for this subject 143 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir, 144 | label_names=name_to_start_codes) 145 | -------------------------------------------------------------------------------- /ml_main_cross_subject.py: -------------------------------------------------------------------------------- 1 | from os import getcwd 2 | from os.path import join 3 | from os.path import dirname 4 | from collections import OrderedDict 5 | from numpy.random import RandomState 6 | from hgdecode.utils import create_log 7 | from hgdecode.utils import ml_results_saver 8 | from hgdecode.loaders import CrossSubject 9 | from hgdecode.classes import CrossValidation 10 | from hgdecode.experiments import FBCSPrLDAExperiment 11 | 12 | """ 13 | SETTING PARAMETERS 14 | Here you can set whatever parameter you want 15 | """ 16 | # setting ml_algorithm 17 | algorithm_name = 'FBCSP_rLDA' 18 | 19 | # setting channel_names 20 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 21 | 'CP5', 'CP1', 'CP2', 'CP6', 22 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 23 | 'CP3', 'CPz', 'CP4', 24 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 25 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 26 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 27 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h', 28 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 29 
| 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h'] 30 | 31 | # setting data_dir & results_dir 32 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma') 33 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode') 34 | 35 | # setting name_to_start_codes 36 | name_to_start_codes = OrderedDict([('Right Hand', [1]), 37 | ('Left Hand', [2]), 38 | ('Rest', [3]), 39 | ('Feet', [4])]) 40 | 41 | # setting random state 42 | random_state = RandomState(1234) 43 | 44 | # setting subject_ids 45 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) 46 | 47 | # %% 48 | """ 49 | STARTING LOADING ROUTINE & COMPUTATION 50 | Here you can change some parameter in function calls as well 51 | """ 52 | # creating a log object 53 | subj_results_dir = create_log( 54 | results_dir=results_dir, 55 | learning_type='ml', 56 | algorithm_or_model_name=algorithm_name, 57 | subject_id='subj_cross', 58 | output_on_file=False 59 | ) 60 | 61 | # creating a cross-subject object for cross-subject validation 62 | cross_obj = CrossSubject(data_dir=data_dir, 63 | subject_ids=subject_ids, 64 | channel_names=channel_names, 65 | name_to_start_codes=name_to_start_codes, 66 | resampling_freq=250, 67 | train_test_split=True, 68 | clean_ival_ms=(-1000, 1000), 69 | epoch_ival_ms=(-1000, 1000), 70 | clean_on_all_channels=False) 71 | 72 | """ 73 | Si potrebbe fare un soft parsing così trova le fold, poi si passa ad exp 74 | quello che gli serve (cnt all, clean all, fold all) e si butta tutto il 75 | resto... 
76 | """ 77 | 78 | # creating experiment instance 79 | exp = FBCSPrLDAExperiment( 80 | # signal-related inputs 81 | cnt=cross_obj.data, 82 | clean_trial_mask=cross_obj.clean_trial_mask, 83 | name_to_start_codes=name_to_start_codes, 84 | random_state=random_state, 85 | name_to_stop_codes=None, # Schirrmeister: None 86 | epoch_ival_ms=(-1000, 1000), # Schirrmeister: (-500, 4000) 87 | cross_subject_object=cross_obj, 88 | 89 | # bank filter-related inputs 90 | min_freq=[0, 10], # Schirrmeister: [0, 10] 91 | max_freq=[12, 122], # Schirrmeister: [12, 122] 92 | window=[6, 8], # Schirrmeister: [6, 8] 93 | overlap=[3, 4], # Schirrmeister: [3, 4] 94 | filt_order=3, # filt_order: 3 95 | 96 | # machine learning parameters 97 | n_folds=14, # Schirrmeister: ? 98 | n_top_bottom_csp_filters=5, # Schirrmeister: 5 99 | n_selected_filterbands=None, # Schirrmeister: None 100 | n_selected_features=20, # Schirrmeister: 20 101 | forward_steps=2, # Schirrmeister: 2 102 | backward_steps=1, # Schirrmeister: 1 103 | stop_when_no_improvement=False, # Schirrmeister: False 104 | shuffle=False, # Schirrmeister: False 105 | average_trial_covariance=True # Schirrmeister: True 106 | ) 107 | 108 | # running the experiment 109 | exp.run() 110 | 111 | # saving results 112 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir) 113 | 114 | # at the very end, running cross-validation 115 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir, 116 | label_names=name_to_start_codes) 117 | -------------------------------------------------------------------------------- /schirrmeister_main.py: -------------------------------------------------------------------------------- 1 | from os import getcwd 2 | from os.path import join 3 | from os.path import dirname 4 | from collections import OrderedDict 5 | from numpy.random import RandomState 6 | from numpy import array 7 | from numpy import floor 8 | from numpy import setdiff1d 9 | from hgdecode.utils import create_log 10 | from 
hgdecode.utils import print_manager 11 | from hgdecode.loaders import ml_loader 12 | from hgdecode.classes import CrossValidation 13 | from hgdecode.experiments import DLExperiment 14 | from hgdecode.experiments import FBCSPrLDAExperiment 15 | from keras import backend as K 16 | from braindecode.datautil.trial_segment import \ 17 | create_signal_target_from_raw_mne 18 | from hgdecode.utils import ml_results_saver 19 | 20 | """ 21 | ONLY PARAMETER YOU CAN CHOSE 22 | ---------------------------- 23 | """ 24 | # set here what type of learning you want 25 | learning_type = 'ml' 26 | 27 | """ 28 | SETTING OTHER PARAMETERS (YOU CANNOT MODIFY THAT) 29 | ------------------------------------------------- 30 | """ 31 | # setting channel_names 32 | channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 33 | 'CP5', 'CP1', 'CP2', 'CP6', 34 | 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 35 | 'CP3', 'CPz', 'CP4', 36 | 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 37 | 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 38 | 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 39 | 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h', 40 | 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 41 | 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h'] 42 | 43 | # setting data_dir & results_dir 44 | data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma') 45 | results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode') 46 | 47 | # setting name_to_start_codes 48 | name_to_start_codes = OrderedDict([('Right Hand', [1]), 49 | ('Left Hand', [2]), 50 | ('Rest', [3]), 51 | ('Feet', [4])]) 52 | 53 | # setting random_state 54 | random_state = RandomState(1234) 55 | 56 | # subject list 57 | subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) 58 | 59 | # computing algorithm_or_model_name and standardize_mode 60 | if learning_type == 'ml': 61 | algorithm_or_model_name = 'FBCSP_rLDA' 62 | standardize_mode = 0 63 | else: 64 | algorithm_or_model_name = 'DeepConvNet' 65 | standardize_mode = 2 66 | 67 | """ 68 | MAIN CYCLE 69 | ---------- 70 | """ 71 | for 
subject_id in subject_ids: 72 | # creating a log object 73 | subj_results_dir = create_log( 74 | results_dir=results_dir, 75 | learning_type=learning_type, 76 | algorithm_or_model_name=algorithm_or_model_name, 77 | subject_id=subject_id, 78 | output_on_file=False 79 | ) 80 | 81 | # loading cnt signal 82 | cnt, clean_trial_mask = ml_loader( 83 | data_dir=data_dir, 84 | name_to_start_codes=name_to_start_codes, 85 | channel_names=channel_names, 86 | subject_id=subject_id, 87 | resampling_freq=250, # Schirrmeister: 250 88 | clean_ival_ms=(-500, 4000), # Schirrmeister: (0, 4000) 89 | train_test_split=True, # Schirrmeister: True 90 | clean_on_all_channels=False, # Schirrmeister: True 91 | standardize_mode=standardize_mode # Schirrmeister: 2 92 | ) 93 | 94 | # splitting two algorithms 95 | if learning_type == 'ml': 96 | # creating experiment instance 97 | exp = FBCSPrLDAExperiment( 98 | # signal-related inputs 99 | cnt=cnt, 100 | clean_trial_mask=clean_trial_mask, 101 | name_to_start_codes=name_to_start_codes, 102 | random_state=random_state, 103 | name_to_stop_codes=None, # Schirrmeister: None 104 | epoch_ival_ms=(-500, 4000), # Schirrmeister: (-500, 4000) 105 | 106 | # bank filter-related inputs 107 | min_freq=[0, 10], # Schirrmeister: [0, 10] 108 | max_freq=[12, 122], # Schirrmeister: [12, 122] 109 | window=[6, 8], # Schirrmeister: [6, 8] 110 | overlap=[3, 4], # Schirrmeister: [3, 4] 111 | filt_order=3, # filt_order: 3 112 | 113 | # machine learning parameters 114 | n_folds=0, # Schirrmeister: ? 
115 | n_top_bottom_csp_filters=5, # Schirrmeister: 5 116 | n_selected_filterbands=None, # Schirrmeister: None 117 | n_selected_features=20, # Schirrmeister: 20 118 | forward_steps=2, # Schirrmeister: 2 119 | backward_steps=1, # Schirrmeister: 1 120 | stop_when_no_improvement=False, # Schirrmeister: False 121 | shuffle=False, # Schirrmeister: False 122 | average_trial_covariance=True # Schirrmeister: True 123 | ) 124 | 125 | # running the experiment 126 | exp.run() 127 | 128 | # saving results for this subject 129 | ml_results_saver(exp=exp, subj_results_dir=subj_results_dir) 130 | 131 | # computing statistics for this subject 132 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir, 133 | label_names=name_to_start_codes) 134 | elif learning_type == 'dl': 135 | # creating schirrmeister fold 136 | all_idxs = array(range(len(clean_trial_mask))) 137 | folds = [ 138 | { 139 | 'train': all_idxs[:-160], 140 | 'test': all_idxs[-160:] 141 | } 142 | ] 143 | folds[0]['train'] = folds[0]['train'][clean_trial_mask[:-160]] 144 | folds[0]['test'] = folds[0]['test'][clean_trial_mask[-160:]] 145 | 146 | # adding validation 147 | valid_idxs = array(range(int(floor(len(clean_trial_mask) * 0.1)))) 148 | folds[0]['train'] = setdiff1d(folds[0]['train'], valid_idxs) 149 | folds[0]['valid'] = valid_idxs 150 | 151 | # parsing cnt to epoched data 152 | print_manager('Epoching...') 153 | epo = create_signal_target_from_raw_mne(cnt, 154 | name_to_start_codes, 155 | (-500, 4000)) 156 | print_manager('DONE!!', bottom_return=1) 157 | 158 | # # cleaning epoched signal with mask 159 | # print_manager('cleaning with mask...') 160 | # epo.X = epo.X[clean_trial_mask] 161 | # epo.y = epo.y[clean_trial_mask] 162 | # print_manager('DONE!!', 'last', bottom_return=1) 163 | 164 | # creating cv instance 165 | cv = CrossValidation(X=epo.X, y=epo.y, shuffle=False) 166 | 167 | # creating EEGDataset for current fold 168 | dataset = cv.create_dataset(fold=folds[0]) 169 | 170 | # clearing TF graph 
(https://github.com/keras-team/keras/issues/3579) 171 | print_manager('CLEARING KERAS BACKEND', print_style='double-dashed') 172 | K.clear_session() 173 | print_manager(print_style='last', bottom_return=1) 174 | 175 | # creating experiment instance 176 | exp = DLExperiment( 177 | # non-default inputs 178 | dataset=dataset, 179 | model_name=algorithm_or_model_name, 180 | results_dir=results_dir, 181 | subj_results_dir=subj_results_dir, 182 | name_to_start_codes=name_to_start_codes, 183 | random_state=random_state, 184 | fold_idx=0, 185 | 186 | # hyperparameters 187 | dropout_rate=0.5, # Schirrmeister: 0.5 188 | learning_rate=1 * 1e-4, # Schirrmeister: ? 189 | batch_size=32, # Schirrmeister: 512 190 | epochs=1000, # Schirrmeister: ? 191 | early_stopping=False, # Schirrmeister: ? 192 | monitor='val_acc', # Schirrmeister: ? 193 | min_delta=0.0001, # Schirrmeister: ? 194 | patience=5, # Schirrmeister: ? 195 | loss='categorical_crossentropy', # Schirrmeister: ad hoc 196 | optimizer='Adam', # Schirrmeister: Adam 197 | shuffle=True, # Schirrmeister: ? 
198 | crop_sample_size=None, # Schirrmeister: 1125 199 | crop_step=None, # Schirrmeister: 1 200 | 201 | # other parameters 202 | subject_id=subject_id, 203 | data_generator=False, # Schirrmeister: True 204 | save_model_at_each_epoch=False 205 | ) 206 | 207 | # training 208 | exp.train() 209 | 210 | # computing cross-validation 211 | CrossValidation.cross_validate(subj_results_dir=subj_results_dir, 212 | label_names=name_to_start_codes) 213 | -------------------------------------------------------------------------------- /sub_routines/dl_cross_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from hgdecode.classes import CrossValidation 4 | 5 | datetime_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \ 6 | 'Università di Bologna/TesiMagistrale_DavideMiani/' \ 7 | 'results/hgdecode/dl/DeepConvNet/2019-01-18_13-33-01' 8 | subj_dirs = os.listdir(datetime_dir) 9 | subj_dirs.sort() 10 | subj_dirs.remove('model_report.txt') 11 | subj_dirs.remove('statistics') 12 | subj_dirs = [os.path.join(datetime_dir, x) for x in subj_dirs] 13 | 14 | name_to_start_codes = OrderedDict([('Right Hand', [1]), 15 | ('Left Hand', [2]), 16 | ('Rest', [3]), 17 | ('Feet', [4])]) 18 | 19 | for subj_dir in subj_dirs: 20 | CrossValidation.cross_validate(subj_results_dir=subj_dir, 21 | label_names=name_to_start_codes) 22 | -------------------------------------------------------------------------------- /sub_routines/latex_tabular_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | from csv import reader 3 | from hgdecode.utils import get_path 4 | from hgdecode.utils import check_significant_digits 5 | 6 | """ 7 | SET HERE YOUR PARAMETERS 8 | """ 9 | # to find file parameters 10 | results_dir = None 11 | learning_type = 'dl' 12 | algorithm_or_model_name = None 13 | epoching = '-1500_500' 14 | fold_type = 'single_subject' 15 | 
n_folds = 12 16 | deprecated = True 17 | balanced_fold = True 18 | 19 | # metrics parameter 20 | label = 'Feet' # Feet, LeftHand, Rest or RightHand 21 | metric_type = 'overall' # label or overall 22 | metric = 'acc' 23 | 24 | """ 25 | GETTING PATHS 26 | """ 27 | # getting folder path 28 | folder_path = get_path( 29 | results_dir=results_dir, 30 | learning_type=learning_type, 31 | algorithm_or_model_name=algorithm_or_model_name, 32 | epoching=epoching, 33 | fold_type=fold_type, 34 | n_folds=n_folds, 35 | deprecated=deprecated, 36 | balanced_folds=balanced_fold 37 | ) 38 | 39 | # getting file_path 40 | file_path = os.path.join(folder_path, 'statistics', 'tables') 41 | if metric_type == 'overall': 42 | file_path = os.path.join(file_path, metric + '.csv') 43 | else: 44 | file_path = os.path.join(file_path, label, metric + '.csv') 45 | 46 | """ 47 | COMPUTATION START HERE 48 | """ 49 | with open(file_path) as f: 50 | csv = list(reader(f)) 51 | 52 | n_folds = len(csv[0]) - 2 53 | columns = ['&\\textbf{' + str(x + 1) + '}\n' for x in range(n_folds)] 54 | 55 | output = '\\begin{table}[H]\n\\footnotesize\n\\centering\n\\begin{tabular}' + \ 56 | '{|c|' + 'c' * n_folds + '|cc|}\n\\hline\n' + \ 57 | '&\multicolumn{' + str(n_folds) + '}{c|}{\\textbf{fold}}& ' + \ 58 | '&\n\\\\\n\\textbf{subj}\n' 59 | for head in columns: 60 | output += head 61 | output += '&\\textbf{mean}\n&\\textbf{std}\n\\\\\n\hline\hline\n' 62 | 63 | # removing header 64 | csv = csv[1:] 65 | total_m = 0 66 | total_s = 0 67 | 68 | for idx, current_row in enumerate(csv): 69 | if idx % 2 is 0: 70 | output += '\\rowcolor[gray]{.9}\n' 71 | else: 72 | output += '\\rowcolor[gray]{.8}\n' 73 | output += '\\textbf{' + str(idx + 1) + '}\n' 74 | for idy, current_col in enumerate(current_row): 75 | output += '&' + check_significant_digits(current_col) + '\n' 76 | output += '\\\\\n' 77 | total_m += float(current_row[-2]) 78 | total_s += float(current_row[-1]) 79 | total_m /= len(csv) 80 | total_s /= len(csv) 81 | total_m 
= check_significant_digits(total_m) 82 | total_s = check_significant_digits(total_s) 83 | 84 | caption = learning_type + ' ' + metric + ' ' + \ 85 | epoching.replace('_', ',') + ' ' + str(n_folds) + ' fold' 86 | output += '\\hline\n\\multicolumn{' + str(n_folds + 1) + '}' + \ 87 | '{|r|}{\\textbf{media totale}}\n&' + total_m + '\n&' + \ 88 | total_s + '\n\\\\\n\\hline\n\\end{tabular}\n' + \ 89 | '\\caption{' + caption + '}\n\\label{' + caption + '}\n' + \ 90 | '\\end{table}' 91 | print(output) 92 | -------------------------------------------------------------------------------- /sub_routines/latex_tabular_parser_cross_subj.py: -------------------------------------------------------------------------------- 1 | import os 2 | from csv import reader 3 | from hgdecode.utils import check_significant_digits 4 | 5 | results_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \ 6 | 'Università di Bologna/TesiMagistrale_DavideMiani/' \ 7 | 'results/hgdecode' 8 | learning_type = 'dl' # dl or ml 9 | algo_or_model_name = 'DeepConvNet' # DeepConvNet or FBCSP_rLDA 10 | datetime = '2019-01-21_11-38-41' 11 | epoch_ival_ms = '-1000, 1000' # str type 12 | tables_dir = os.path.join(results_dir, 13 | learning_type, 14 | algo_or_model_name, 15 | datetime, 16 | 'statistics', 17 | 'tables') 18 | label_names = ['RightHand', 'LeftHand', 'Rest', 'Feet'] 19 | 20 | # pre-allocating csv 21 | csv = [] 22 | 23 | # getting accuracy 24 | with open(os.path.join(tables_dir, 'acc.csv')) as f: 25 | temp = list(reader(f)) 26 | temp = temp[1] 27 | csv.append(temp) 28 | 29 | # getting precision 30 | for label in label_names: 31 | with open(os.path.join(tables_dir, label, 'prec.csv')) as f: 32 | temp = list(reader(f)) 33 | temp = temp[1] 34 | csv.append(temp) 35 | 36 | # getting f1 score 37 | for label in label_names: 38 | with open(os.path.join(tables_dir, label, 'f1.csv')) as f: 39 | temp = list(reader(f)) 40 | temp = temp[1] 41 | csv.append(temp) 42 | 43 | # transposing csv 44 | csv = 
list(map(list, zip(*csv))) 45 | 46 | # cutting away last two rows (mean and std) 47 | csv_2 = csv[-2:] 48 | csv = csv[0:-2] 49 | 50 | output = '\\begin{table}[H]\n\\footnotesize\n\\centering\n\\begin{tabular}' + \ 51 | '{|c|ccccccccc|}\n\\hline\n' + \ 52 | '\\textbf{fold}\n&' + \ 53 | '\\textbf{acc}\n&' + \ 54 | '\\textbf{pr 1}\n&' + \ 55 | '\\textbf{pr 2}\n&' + \ 56 | '\\textbf{pr 3}\n&' + \ 57 | '\\textbf{pr 4}\n&' + \ 58 | '\\textbf{f1 1}\n&' + \ 59 | '\\textbf{f1 2}\n&' + \ 60 | '\\textbf{f1 3}\n&' + \ 61 | '\\textbf{f1 4}\n' + \ 62 | '\\\\\n\\hline\\hline\n' 63 | 64 | for idx, current_row in enumerate(csv): 65 | if idx % 2 is 0: 66 | output += '\\rowcolor[gray]{.9}\n' 67 | else: 68 | output += '\\rowcolor[gray]{.8}\n' 69 | output += '\\textbf{' + str(idx + 1) + '}\n' 70 | for idy, current_col in enumerate(current_row): 71 | output += '&' + check_significant_digits(current_col) + '\n' 72 | output += '\\\\\n' 73 | output += '\\hline\n' 74 | 75 | for idx, current_row in enumerate(csv_2): 76 | if idx == 0: 77 | output += '\\textbf{mean}\n' 78 | else: 79 | output += '\\textbf{std}\n' 80 | for idy, current_col in enumerate(current_row): 81 | output += '&' + check_significant_digits(current_col) + '\n' 82 | output += '\\\\\n' 83 | output += '\\hline\n' 84 | 85 | caption = learning_type + ' cross-subject validation ' + epoch_ival_ms 86 | output += '\\end{tabular}\n' + \ 87 | '\\caption{' + caption + '}\n\\label{' + caption + '}\n' + \ 88 | '\\end{table}' 89 | print(output) 90 | -------------------------------------------------------------------------------- /sub_routines/latex_tabular_parser_transfer_learning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from csv import reader 4 | from scipy.stats import ttest_rel 5 | from hgdecode.utils import get_path 6 | from hgdecode.utils import check_significant_digits 7 | 8 | """ 9 | SET HERE YOUR PARAMETERS 10 | """ 11 | ival = (-1000, 1000) 12 | 
train_trials_list = [4, 8, 16, 32, 64, 128]
reference = 1  # baseline column: 0 for ML cross, 1 for DL cross, 2 for TL4 etc.
p_flag = False  # if true, it will print p value too.

"""
GETTING PATHS
"""
# first two columns: cross-subject baselines (FBCSP_rLDA and DeepConvNet)
folder_paths = [
    get_path(
        results_dir=None,
        learning_type=x,
        algorithm_or_model_name=y,
        epoching=ival,
        fold_type='cross_subject',
        n_folds=None,
        deprecated=False
    )
    for x, y in zip(['ml', 'dl'], ['FBCSP_rLDA', 'DeepConvNet'])
]

# remaining columns: one transfer-learning run per training-set size
folder_paths += [
    get_path(
        results_dir=None,
        learning_type='dl',
        algorithm_or_model_name='DeepConvNet',
        epoching=ival,
        fold_type='transfer_learning',
        n_folds=x,
        deprecated=False
    )
    for x in train_trials_list
]

# getting file_path
csv_paths = [
    os.path.join(x, 'statistics', 'tables', 'acc.csv')
    for x in folder_paths
]

"""
GETTING DATA FROM CSV FILES
"""
subj_data = []
mean_data = []
stdd_data = []
for idx, csv_path in enumerate(csv_paths):
    with open(csv_path) as f:
        # skip the header row and convert every cell to float
        table = [list(map(float, line)) for line in list(reader(f))[1:]]
    if idx < 2:
        # cross-subject tables: a single data row with the per-subject
        # accuracies followed by two summary cells (mean, std)
        subj_data.append(table[0][:-2])
        mean_data.append(table[0][-2])
        stdd_data.append(table[0][-1])
    else:
        # transfer-learning tables: one row per subject; the second-to-last
        # column is taken as that subject's accuracy
        temp_data = [line[-2] for line in table]
        subj_data.append(temp_data)
        mean_data.append(np.mean(temp_data))
        stdd_data.append(np.std(temp_data))

"""
COMPUTING PERC AND PVAL
"""
# relative change of each column's mean w.r.t. the reference column, plus the
# paired t-test p-value of the per-subject accuracies against it
perc_data = []
pval_data = []
for idx in range(len(mean_data)):
    if idx == reference:  # value comparison, not identity (`is`)
        perc_data.append(0.)
        pval_data.append(float('nan'))
    else:
        perc_data.append((mean_data[idx] - mean_data[reference]) /
                         mean_data[reference])
        pval_data.append(ttest_rel(subj_data[idx], subj_data[reference])[1])

"""
GENERAL FORMATTING
"""
n_subjs = len(subj_data[0])
columns = [['subj'] + list(map(str, range(1, n_subjs + 1))) +
           ['mean', 'std', '$\\Delta_{\\textbf{\\%}}$', '$p$']]
# transfer-learning headers derived from train_trials_list so they cannot
# drift out of sync with the loaded columns
header = ['ML', 'DL'] + list(map(str, train_trials_list))
for idx, head in enumerate(header):
    temp = [check_significant_digits(str(subj_data[idx][x]))
            for x in range(n_subjs)]
    columns.append(
        [head] +
        temp +
        [str(check_significant_digits(mean_data[idx]))] +
        [str(check_significant_digits(stdd_data[idx]))] +
        [str(check_significant_digits(perc_data[idx]))] +
        [str(check_significant_digits(pval_data[idx]))]
    )
# transpose the columns into table rows
rows = list(map(list, zip(*columns)))
if not p_flag:
    rows.pop()  # drop the trailing p-value row

"""
CREATING LATEX TABULAR CODE
"""
# pre-allocating output
output = ''

# opening table (`\\multicolumn`: a bare `\m` is an invalid escape sequence)
output += '\\begin{table}[H]\n\\footnotesize\n\\centering\n'
output += '\\begin{tabular}{|c|M{1.4cm}M{1.4cm}|'
output += 'c' * len(train_trials_list) + '|}\n'
output += '\\hline\n&\\multicolumn{2}{c|}{\\textbf{cross-soggetto}}\n'
output += '&\\multicolumn{' + str(len(train_trials_list)) + '}{c|}'
output += '{\\textbf{transfer learning}}\n\\\\\n'

# first row is a header
for idx, col in enumerate(rows[0]):
    if idx == 0:
        output += '\\textbf{' + col + '}\n'
    else:
        output += '&\\textbf{' + col + '}\n'
output += '\\\\\n\\hline\n\\hline\n'

# body rows (header skipped) with alternating gray shades; a double rule
# separates the per-subject rows from the summary rows
for idx, row in enumerate(rows[1:]):
    if idx % 2 == 0:
        output += '\\rowcolor[gray]{.9}\n'
    else:
        output += '\\rowcolor[gray]{.8}\n'
    for idy, col in enumerate(row):
        if idy == 0:
            output += '\\textbf{' + col + '}\n'
        else:
            output += '&' + col + '\n'
    output += '\\\\\n'
    if idx == n_subjs - 1:
        output += '\\hline\n\\hline\n'

output += '\\hline\n\\end{tabular}'
output += '\n\\caption{tl table}\n\\label{tl table}\n'
output += '\\end{table}'
print(output)

# ------------------------------------------------------------------------------
# /sub_routines/latex_tabular_parser_transfer_learning_frozen_layers.py
# ------------------------------------------------------------------------------
import os
import numpy as np
from csv import reader
from scipy.stats import ttest_rel
from hgdecode.utils import get_path
from hgdecode.utils import check_significant_digits

"""
SET HERE YOUR PARAMETERS
"""
ival = (-500, 4000)
frozen_layers_list = [1, 2, 3, 4, 5, 6, -5, -4, -3, -2, -1]
# baseline column: 0 for DL cross-subject (CL), 1 for transfer learning with
# no frozen layers, 2+ for the frozen_layers_list runs in list order
reference = 0
p_flag = False  # if true, it will print p value too.
"""
GETTING PATHS
"""
# column 0: DL cross-subject model
folder_paths = [get_path(
    results_dir=None,
    learning_type='dl',
    algorithm_or_model_name=None,
    epoching=ival,
    fold_type='cross_subject',
    n_folds=None,
    deprecated=False
)]

# column 1: transfer learning with 128 training trials and no frozen layers
folder_paths += [get_path(
    results_dir=None,
    learning_type='dl',
    algorithm_or_model_name=None,
    epoching=ival,
    fold_type='transfer_learning',
    n_folds=128,
    deprecated=False
)]

# remaining columns: one run per entry of frozen_layers_list
folder_paths += [
    get_path(
        results_dir=None,
        learning_type='dl',
        algorithm_or_model_name='DeepConvNet',
        epoching=ival,
        fold_type='transfer_learning_frozen',
        n_folds=x,
        deprecated=False
    )
    for x in frozen_layers_list
]

# getting file_path
csv_paths = [
    os.path.join(x, 'statistics', 'tables', 'acc.csv')
    for x in folder_paths
]

"""
GETTING DATA FROM CSV FILES
"""
subj_data = []
mean_data = []
stdd_data = []
for idx, csv_path in enumerate(csv_paths):
    with open(csv_path) as f:
        # skip the header row and convert every cell to float
        table = [list(map(float, line)) for line in list(reader(f))[1:]]
    if idx < 1:
        # cross-subject table: a single data row with the per-subject
        # accuracies followed by two summary cells (mean, std)
        subj_data.append(table[0][:-2])
        mean_data.append(table[0][-2])
        stdd_data.append(table[0][-1])
    else:
        # transfer-learning tables: one row per subject; the second-to-last
        # column is taken as that subject's accuracy
        temp_data = [line[-2] for line in table]
        subj_data.append(temp_data)
        mean_data.append(np.mean(temp_data))
        stdd_data.append(np.std(temp_data))

"""
COMPUTING PERC AND PVAL
"""
# relative change of each column's mean w.r.t. the reference column, plus the
# paired t-test p-value of the per-subject accuracies against it
perc_data = []
pval_data = []
for idx in range(len(mean_data)):
    if idx == reference:  # value comparison, not identity (`is`)
        perc_data.append(0.)
        pval_data.append(float('nan'))
    else:
        perc_data.append((mean_data[idx] - mean_data[reference]) /
                         mean_data[reference])
        pval_data.append(ttest_rel(subj_data[idx], subj_data[reference])[1])

"""
GENERAL FORMATTING
"""
n_subjs = len(subj_data[0])
columns = [['subj'] + list(map(str, range(1, n_subjs + 1))) +
           ['mean', 'std', '$\\Delta_{\\textbf{\\%}}$', '$p$']]
header = ['CL', '0'] + list(map(str, frozen_layers_list))
for idx, head in enumerate(header):
    temp = [check_significant_digits(str(subj_data[idx][x]))
            for x in range(n_subjs)]
    columns.append(
        [head] +
        temp +
        [str(check_significant_digits(mean_data[idx]))] +
        [str(check_significant_digits(stdd_data[idx]))] +
        [str(check_significant_digits(perc_data[idx]))] +
        [str(check_significant_digits(pval_data[idx]))]
    )
# transpose the columns into table rows
rows = list(map(list, zip(*columns)))
if not p_flag:
    rows.pop()  # drop the trailing p-value row

"""
CREATING LATEX TABULAR CODE
"""
# pre-allocating output
output = ''

# opening table (`\\multicolumn`: a bare `\m` is an invalid escape sequence)
output += '\\begin{table}[H]\n\\footnotesize\n\\centering\n'
output += '\\begin{tabular}{|c|cc|'
output += 'c' * len(frozen_layers_list) + '|}\n'
output += '\\hline\n&&\n'
output += '&\\multicolumn{' + str(len(frozen_layers_list)) + '}{c|}'
output += '{\\textbf{transfer learning con strati congelati}}\n\\\\\n'

# first row is a header
for idx, col in enumerate(rows[0]):
    if idx == 0:
        output += '\\textbf{' + col + '}\n'
    else:
        output += '&\\textbf{' + col + '}\n'
output += '\\\\\n\\hline\n\\hline\n'

# body rows (header skipped) with alternating gray shades; a double rule
# separates the per-subject rows from the summary rows
for idx, row in enumerate(rows[1:]):
    if idx % 2 == 0:
        output += '\\rowcolor[gray]{.9}\n'
    else:
        output += '\\rowcolor[gray]{.8}\n'
    for idy, col in enumerate(row):
        if idy == 0:
            output += '\\textbf{' + col + '}\n'
        else:
            output += '&' + col + '\n'
    output += '\\\\\n'
    if idx == n_subjs - 1:
        output += '\\hline\n\\hline\n'

output += '\\hline\n\\end{tabular}'
output += '\n\\caption{tl table}\n\\label{tl table}\n'
output += '\\end{table}'
print(output)

# ------------------------------------------------------------------------------
# /sub_routines/learning_curve.py
# ------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from pylab import savefig
from pickle import load
from os.path import join
from os.path import dirname
from hgdecode.utils import touch_dir
from hgdecode.utils import get_subj_str
from hgdecode.utils import get_fold_str
from hgdecode.utils import get_path

"""
SET HERE YOUR PARAMETERS
"""
# to find file parameters
results_dir = None
learning_type = 'dl'
algorithm_or_model_name = None
epoching = '-500_4000'
fold_type = 'single_subject'
n_folds_list = [2, 4, 6, 8, 10, 12]  # must be a list of integer
deprecated = False

fontsize_1 = 35
fontsize_2 = 27.5
fig_size = (22, 7.5)

"""
GETTING PATHS
"""
# trainings stuff: one result folder per number of folds
folder_paths = [
    get_path(
        results_dir=results_dir,
        learning_type=learning_type,
        algorithm_or_model_name=algorithm_or_model_name,
        epoching=epoching,
        fold_type=fold_type,
        n_folds=x,
        deprecated=deprecated
    )
    for x in n_folds_list
]
n_trainings = len(folder_paths)

# saving stuff
savings_dir = join(dirname(dirname(folder_paths[0])), 'learning_curve')
touch_dir(savings_dir)

"""
SUBJECTS STUFF
"""
# subject stuff: number of trials available for each subject
n_trials_list = [480, 973, 1040, 1057, 880, 1040, 1040, 814, 1040, 1040,
                 1040, 1040, 950, 1040]
n_subjects = len(n_trials_list)
subj_str_list = [get_subj_str(x + 1) for x in range(n_subjects)]

"""
COMPUTATION STARTS HERE
"""
# pre-allocating results dictionary (one list per subject under each key)
results = {
    'n_folds': [],
    'n_trials': [],
    'n_train_trials': [],
    'n_valid_trials': [],
    'n_test_trials': [],
    'perc_train_trials': [],
    'perc_valid_trials': [],
    'perc_test_trials': [],
    'm_acc': [],
    's_acc': []
}

# cycling on subject
for subj, current_n_trials in zip(subj_str_list, n_trials_list):
    n_folds = []
    n_trials = []
    n_train_trials = []
    n_valid_trials = []
    n_test_trials = []
    perc_train_trials = []
    perc_valid_trials = []
    perc_test_trials = []
    m_acc = []
    s_acc = []

    # cycling on all possible fold splits
    for idx, current_n_folds in enumerate(n_folds_list):
        n_folds.append(current_n_folds)
        n_trials.append(current_n_trials)
        # deep learning holds out 10% of the trials for validation
        if learning_type == 'dl':
            n_valid_trials.append(int(np.floor(n_trials[idx] * 0.1)))
        else:
            n_valid_trials.append(0)
        n_train_trials.append(
            int(np.ceil(n_trials[idx] / n_folds[idx]) * (n_folds[idx] - 1)) -
            n_valid_trials[idx])
        n_test_trials.append(n_trials[idx] - n_train_trials[idx] -
                             n_valid_trials[idx])
        perc_train_trials.append(
            np.round(n_train_trials[idx] / n_trials[idx] * 100, 1))
        perc_valid_trials.append(
            np.round(n_valid_trials[idx] / n_trials[idx] * 100, 1))
        perc_test_trials.append(
            np.round(100 - perc_valid_trials[idx] - perc_train_trials[idx], 1))

        # cycling on folds: collect each fold's test accuracy
        folds_acc = []
        for fold_str in [get_fold_str(x + 1) for x in range(n_folds[idx])]:
            file_path = join(folder_paths[idx],
                             subj, fold_str, 'fold_stats.pickle')
            # pickle files are this project's own result files (trusted input)
            with open(file_path, 'rb') as f:
                fold_stats = load(f)
            folds_acc.append(fold_stats['test']['acc'])
        m_acc.append(np.mean(folds_acc) * 100)
        s_acc.append(np.std(folds_acc) * 100)

    # assigning results for this subject
    results['n_folds'].append(n_folds)
    results['n_trials'].append(n_trials)
    results['n_train_trials'].append(n_train_trials)
    results['n_valid_trials'].append(n_valid_trials)
    results['n_test_trials'].append(n_test_trials)
    results['perc_train_trials'].append(perc_train_trials)
    results['perc_valid_trials'].append(perc_valid_trials)
    results['perc_test_trials'].append(perc_test_trials)
    results['m_acc'].append(m_acc)
    results['s_acc'].append(s_acc)

    # plotting learning curve for this subject
    m_acc = np.array(m_acc)
    s_acc = np.array(s_acc)
    plot_path = join(savings_dir, subj)
    if learning_type == 'dl':
        title = '{} learning curve\n'.format(subj) + \
                '({} samples, {} validation samples)'.format(
                    n_trials[0], n_valid_trials[0])
    else:
        title = '{} learning curve\n({} samples)'.format(subj, n_trials[0])
    x_tick_labels = ['{}\n({} folds)'.format(trials, folds)
                     for trials, folds in zip(n_train_trials, n_folds)]
    plt.figure(dpi=100, figsize=(12.8, 7.2), facecolor='w', edgecolor='k')
    # NOTE(review): matplotlib >= 3.6 renamed this style to
    # 'seaborn-v0_8-whitegrid' (the old name is gone in 3.8)
    plt.style.use('seaborn-whitegrid')
    plt.errorbar(x=n_folds, y=m_acc, yerr=s_acc,
                 fmt='-.o', color='b', ecolor='r',
                 linewidth=2, elinewidth=3, capsize=20, capthick=2)
    plt.xlabel('training samples', fontsize=25)
    plt.ylabel('accuracy (%)', fontsize=25)
    plt.xticks(n_folds, labels=x_tick_labels, fontsize=20)
    plt.yticks(fontsize=20)
    plt.title(title, fontsize=25)

    # saving figure, then releasing it so figures do not pile up per subject
    savefig(plot_path, bbox_inches='tight')
    plt.close()

# getting data for last plot
n_trials = np.array(results['n_trials'])
n_train_trials = np.array(results['n_train_trials'])
n_valid_trials = np.array(results['n_valid_trials'])

# averaging data across subjects (axis 0)
m_n_trials = int(np.round(np.mean(n_trials, axis=0))[0].tolist())
s_n_trials = int(np.round(np.std(n_trials, axis=0))[0].tolist())
m_n_train_trials = np.round(np.mean(n_train_trials, axis=0)).tolist()
s_n_train_trials = np.round(np.std(n_train_trials, axis=0)).tolist()
m_n_train_trials = list(map(int, m_n_train_trials))
s_n_train_trials = list(map(int, s_n_train_trials))
m_n_valid_trials = np.round(np.mean(n_valid_trials, axis=0)).tolist()
s_n_valid_trials = np.round(np.std(n_valid_trials, axis=0)).tolist()
m_n_valid_trials = list(map(int, m_n_valid_trials))
s_n_valid_trials = list(map(int, s_n_valid_trials))
m_acc = np.mean(np.array(results['m_acc']), axis=0)
s_acc = np.mean(np.array(results['s_acc']), axis=0)

# plotting learning curve for total mean
# (`\\pm`: a bare `\p` is an invalid escape sequence; the value is unchanged)
plot_path = join(savings_dir, learning_type + '_learning_curve')
if learning_type == 'dl':
    title = 'totale esempi per soggetto: {}$\\pm${}; esempi di validazione: ' \
            '{}$\\pm${}'.format(m_n_trials, s_n_trials,
                               m_n_valid_trials[0], s_n_valid_trials[0])
else:
    title = 'totale esempi per soggetto: {}$\\pm${}'.format(
        m_n_trials, s_n_trials)
x_tick_labels = ['{}$\\pm${}\n({} fold)'.format(
    m_n_train_trials[idx], s_n_train_trials[idx], n_folds_list[idx])
    for idx in range(n_trainings)]
plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
plt.style.use('seaborn-whitegrid')
plt.errorbar(x=n_folds_list, y=m_acc, yerr=s_acc,
             fmt='-.o', color='b', ecolor='r',
             linewidth=2, elinewidth=3, capsize=20, capthick=2)
plt.xlabel('esempi di training', fontsize=fontsize_1)
plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
plt.xticks(n_folds_list, labels=x_tick_labels, fontsize=fontsize_2)
plt.yticks(fontsize=fontsize_2)
plt.title(title, fontsize=fontsize_1)

# in case of single_subject, -500,4000 (compare by value, not identity)
if fold_type == 'single_subject':
    if epoching == '-500_4000':
        plt.ylim(80, 100)

# saving figure
savefig(plot_path, bbox_inches='tight')

# ------------------------------------------------------------------------------
# /sub_routines/ml_cross_validation.py
# ------------------------------------------------------------------------------
import os
from collections import OrderedDict
from hgdecode.classes import CrossValidation

# NOTE(review): hard-coded absolute path to one specific results run
datetime_dir = '/Users/davidemiani/OneDrive - Alma Mater Studiorum ' \
               'Università di Bologna/TesiMagistrale_DavideMiani/' \
               'results/hgdecode/ml/FBCSP_rLDA/2019-01-15_01-54-25'
# every entry except 'statistics' is treated as a subject directory
subj_dirs = os.listdir(datetime_dir)
subj_dirs.sort()
subj_dirs.remove('statistics')
subj_dirs = [os.path.join(datetime_dir, x) for x in subj_dirs]

name_to_start_codes = OrderedDict([('Right Hand', [1]),
                                   ('Left Hand', [2]),
                                   ('Rest', [3]),
                                   ('Feet', [4])])

for subj_dir in subj_dirs:
    CrossValidation.cross_validate(subj_results_dir=subj_dir,
                                   label_names=name_to_start_codes)

# ------------------------------------------------------------------------------
# /sub_routines/t_test.py
# ------------------------------------------------------------------------------
import os
from csv import reader
from numpy import array
from scipy.stats import ttest_rel
from hgdecode.utils import get_path

"""
TRAINING 1
"""
results_dir = None
learning_type = 'dl'
algorithm_or_model_name = None
epoching = '-1000_1000'
fold_type_1 = 'single_subject'
n_folds_list = [12]  # must be a list of integer
deprecated = False
balanced_folds = True
folder_paths_1 = [
    get_path(
        results_dir=results_dir,
        learning_type=learning_type,
        algorithm_or_model_name=algorithm_or_model_name,
        epoching=epoching,
        fold_type=fold_type_1,
        n_folds=x,
        deprecated=deprecated,
        balanced_folds=balanced_folds
    )
    for x in n_folds_list
]

"""
TRAINING 2
"""
results_dir = None
learning_type = 'ml'
algorithm_or_model_name = None
epoching = '-500_4000'
fold_type_2 = 'single_subject'
n_folds_list = [12]  # must be a list of integer
deprecated = False
balanced_folds = True
folder_paths_2 = [
    get_path(
        results_dir=results_dir,
        learning_type=learning_type,
        algorithm_or_model_name=algorithm_or_model_name,
        epoching=epoching,
        fold_type=fold_type_2,
        n_folds=x,
        deprecated=deprecated,
        balanced_folds=balanced_folds
    )
    for x in n_folds_list
]


def _load_accuracies(folder_path, fold_type):
    """Load the per-subject accuracies of one training from its acc.csv.

    Reads ``<folder_path>/statistics/tables/acc.csv``. For
    ``fold_type == 'cross_subject'`` the accuracies are the first data row
    minus its two trailing summary cells; otherwise each data row is one
    subject and its second-to-last cell is used.

    Returns a numpy array of floats.
    """
    csv_path = os.path.join(folder_path, 'statistics', 'tables', 'acc.csv')
    with open(csv_path) as f:
        table = list(reader(f))
    if fold_type == 'cross_subject':
        return array(list(map(float, table[1][:-2])))
    return array([float(line[-2]) for line in table[1:]])


"""
T-TESTING
"""
for training_1, training_2 in zip(folder_paths_1, folder_paths_2):
    training_1_accs = _load_accuracies(training_1, fold_type_1)
    training_2_accs = _load_accuracies(training_2, fold_type_2)

    # running a paired t-test (assumes both tables list the same subjects in
    # the same order)
    statistic, p_value = ttest_rel(training_1_accs, training_2_accs)
    print(p_value)

# ------------------------------------------------------------------------------
# /sub_routines/transfer_learning_curve.py
# ------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from pylab import savefig
from pickle import load
from os.path import join
from os.path import dirname
from hgdecode.utils import listdir2
from hgdecode.utils import touch_dir
from hgdecode.utils import get_subj_str
from hgdecode.utils import get_fold_str
from hgdecode.utils import get_path


def get_fold_number(folder_path, subj_str):
    """Return how many entries ``listdir2`` reports in folder_path/subj_str.

    The result is used as the number of folds of that subject (presumably one
    sub-directory per fold - confirm against ``listdir2``).
    """
    return len(listdir2(join(folder_path, subj_str)))


"""
SET HERE YOUR PARAMETERS
"""
# to find file parameters
results_dir = None
learning_type = 'dl'
algorithm_or_model_name = None
epoching = '-1000_1000'
fold_type = 'transfer_learning'
train_size_list = [4, 8, 16, 32, 64, 128]  # must be a list of integer
deprecated = False

fontsize_1 = 35
fontsize_2 = 27.5
fig_size = (22, 7.5)

"""
GETTING PATHS
"""
# trainings stuff: one result folder per training-set size
folder_paths = [
    get_path(
        results_dir=results_dir,
        learning_type=learning_type,
        algorithm_or_model_name=algorithm_or_model_name,
        epoching=epoching,
        fold_type=fold_type,
        n_folds=x,
        deprecated=deprecated
    )
    for x in train_size_list
]
n_trainings = len(folder_paths)

# evenly spaced x positions (2, 4, 6, ...) so that the exponentially growing
# training sizes are drawn equidistant; one per training
x_positions = [2 * (k + 1) for k in range(len(train_size_list))]

# saving stuff
savings_dir = join(dirname(dirname(folder_paths[0])), 'learning_curve')
touch_dir(savings_dir)

"""
SUBJECTS STUFF
"""
# subject stuff: number of trials available for each subject
n_trials_list = [480, 973, 1040, 1057, 880, 1040, 1040, 814, 1040, 1040,
                 1040, 1040, 950, 1040]
n_subjects = len(n_trials_list)
subj_str_list = [get_subj_str(x + 1) for x in range(n_subjects)]

"""
COMPUTATION STARTS HERE
"""
# pre-allocating results dictionary (one list per subject under each key)
results = {
    'n_folds': [],
    'n_trials': [],
    'n_train_trials': [],
    'n_valid_trials': [],
    'n_test_trials': [],
    'perc_train_trials': [],
    'perc_valid_trials': [],
    'perc_test_trials': [],
    'm_acc': [],
    's_acc': []
}

# cycling on subject
for subj, current_n_trials in zip(subj_str_list, n_trials_list):
    n_folds = []
    n_trials = []
    n_train_trials = []
    n_valid_trials = []
    n_test_trials = []
    perc_train_trials = []
    perc_valid_trials = []
    perc_test_trials = []
    m_acc = []
    s_acc = []

    # cycling on all training-set sizes
    for idx, current_train_size in enumerate(train_size_list):
        n_folds.append(get_fold_number(folder_paths[idx], subj))
        n_trials.append(current_n_trials)
        n_valid_trials.append(int(np.floor(n_trials[idx] * 0.1)))
        n_train_trials.append(current_train_size)
        n_test_trials.append(n_trials[idx] - n_train_trials[idx] -
                             n_valid_trials[idx])
        perc_train_trials.append(
            np.round(n_train_trials[idx] / n_trials[idx] * 100, 1))
        perc_valid_trials.append(
            np.round(n_valid_trials[idx] / n_trials[idx] * 100, 1))
        perc_test_trials.append(
            np.round(100 - perc_valid_trials[idx] - perc_train_trials[idx], 1))

        # cycling on folds: collect each fold's test accuracy
        folds_acc = []
        for fold_str in [get_fold_str(x + 1) for x in range(n_folds[idx])]:
            file_path = join(folder_paths[idx],
                             subj, fold_str, 'fold_stats.pickle')
            # pickle files are this project's own result files (trusted input)
            with open(file_path, 'rb') as f:
                fold_stats = load(f)
            folds_acc.append(fold_stats['test']['acc'])
        m_acc.append(np.mean(folds_acc) * 100)
        s_acc.append(np.std(folds_acc) * 100)

    # assigning results for this subject
    results['n_folds'].append(n_folds)
    results['n_trials'].append(n_trials)
    results['n_train_trials'].append(n_train_trials)
    results['n_valid_trials'].append(n_valid_trials)
    results['n_test_trials'].append(n_test_trials)
    results['perc_train_trials'].append(perc_train_trials)
    results['perc_valid_trials'].append(perc_valid_trials)
    results['perc_test_trials'].append(perc_test_trials)
    results['m_acc'].append(m_acc)
    results['s_acc'].append(s_acc)

    # plotting learning curve for this subject
    m_acc = np.array(m_acc)
    s_acc = np.array(s_acc)
    plot_path = join(savings_dir, subj)
    if learning_type == 'dl':
        title = '{} transfer learning curve\n'.format(subj) + \
                '({} samples, {} validation samples)'.format(
                    n_trials[0], n_valid_trials[0])
    else:
        title = '{} transfer learning curve\n({} samples)'.format(subj,
                                                                   n_trials[0])
    x_tick_labels = ['{}\n({} folds)'.format(trials, folds)
                     for trials, folds in zip(n_train_trials, n_folds)]
    plt.figure(dpi=100, figsize=(12.8, 7.2), facecolor='w', edgecolor='k')
    # NOTE(review): matplotlib >= 3.6 renamed this style to
    # 'seaborn-v0_8-whitegrid' (the old name is gone in 3.8)
    plt.style.use('seaborn-whitegrid')
    plt.errorbar(x=x_positions, y=m_acc, yerr=s_acc,
                 fmt='-.o', color='b', ecolor='r',
                 linewidth=2, elinewidth=3, capsize=20, capthick=2)
    plt.xlabel('training samples', fontsize=25)
    plt.ylabel('accuracy (%)', fontsize=25)
    plt.xticks(x_positions, labels=x_tick_labels, fontsize=20)
    plt.yticks(fontsize=20)
    plt.title(title, fontsize=25)

    # saving figure, then releasing it so figures do not pile up per subject
    savefig(plot_path, bbox_inches='tight')
    plt.close()

# getting data for last plot
n_trials = np.array(results['n_trials'])
n_train_trials = np.array(results['n_train_trials'])
n_valid_trials = np.array(results['n_valid_trials'])
n_folds = np.array(results['n_folds'])

# averaging data across subjects (axis 0)
m_n_trials = int(np.round(np.mean(n_trials, axis=0))[0].tolist())
s_n_trials = int(np.round(np.std(n_trials, axis=0))[0].tolist())
m_n_train_trials = train_size_list
m_n_valid_trials = np.round(np.mean(n_valid_trials, axis=0)).tolist()
s_n_valid_trials = np.round(np.std(n_valid_trials, axis=0)).tolist()
m_n_valid_trials = list(map(int, m_n_valid_trials))
s_n_valid_trials = list(map(int, s_n_valid_trials))
m_n_folds = np.round(np.mean(n_folds, axis=0)).tolist()
s_n_folds = np.round(np.std(n_folds, axis=0)).tolist()
m_n_folds = list(map(int, m_n_folds))
s_n_folds = list(map(int, s_n_folds))
m_acc = np.mean(np.array(results['m_acc']), axis=0)
s_acc = np.mean(np.array(results['s_acc']), axis=0)

# evenly spaced x positions (2, 4, 6, ...), one per training
x_positions = [2 * (k + 1) for k in range(n_trainings)]

# plotting learning curve for total mean
# (`\\pm`: a bare `\p` is an invalid escape sequence; the value is unchanged)
plot_path = join(savings_dir, 'transfer_learning_curve')
if learning_type == 'dl':
    title = 'totale esempi per soggetto: {}$\\pm${}; esempi di validazione: ' \
            '{}$\\pm${}'.format(m_n_trials, s_n_trials,
                               m_n_valid_trials[0], s_n_valid_trials[0])
else:
    title = '{} learning curve\n({}$\\pm${} samples)'.format(
        'average', m_n_trials, s_n_trials)
x_tick_labels = ['{}\n({}$\\pm${} fold)'.format(
    m_n_train_trials[idx], m_n_folds[idx], s_n_folds[idx])
    for idx in range(n_trainings)]
plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
plt.style.use('seaborn-whitegrid')
plt.errorbar(x=x_positions, y=m_acc, yerr=s_acc,
             fmt='-.o', color='b', ecolor='r',
             linewidth=2, elinewidth=3, capsize=20, capthick=2)
plt.xlabel('esempi di training', fontsize=fontsize_1)
plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
plt.xticks(x_positions, labels=x_tick_labels, fontsize=fontsize_2)
plt.yticks(fontsize=fontsize_2)
plt.title(title, fontsize=fontsize_1)

# in case of single_subject, -500,4000 (compare by value, not identity)
if fold_type == 'single_subject':
    if epoching == '-500_4000':
        plt.ylim(80, 100)

# saving figure
savefig(plot_path, bbox_inches='tight')

# ------------------------------------------------------------------------------
# /sub_routines/transfer_learning_curve_frozen_layers.py
# ------------------------------------------------------------------------------
import os
from os.path import join, dirname
import numpy as np
import matplotlib.pyplot as plt
from pylab import savefig
from csv import reader
from hgdecode.utils import get_path, touch_dir

"""
SET HERE YOUR PARAMETERS
"""
ival = (-500, 4000)
frozen_layers_list = [1, 2, 3, 4, 5, -1, -2, -3, -4, -5, 6]

fontsize_1 = 35
fontsize_2 = 27.5
fig_size = (22, 7.5)

"""
GETTING PATHS
"""
# entry 0: transfer learning with 128 training trials and no frozen layers
folder_paths = [get_path(
    results_dir=None,
    learning_type='dl',
    algorithm_or_model_name=None,
    epoching=ival,
    fold_type='transfer_learning',
    n_folds=128,
    deprecated=False
)]

# entries 1..: one run per value of frozen_layers_list, in list order
folder_paths += [
    get_path(
        results_dir=None,
        learning_type='dl',
        algorithm_or_model_name='DeepConvNet',
        epoching=ival,
        fold_type='transfer_learning_frozen',
        n_folds=x,
        deprecated=False
    )
    for x in frozen_layers_list
]

# getting file_path
csv_paths = [
    os.path.join(x, 'statistics', 'tables', 'acc.csv')
    for x in folder_paths
]

"""
SETTING PATHS
"""
# saving stuff
savings_dir = join(dirname(
    dirname(dirname(folder_paths[1]))), 'learning_curve')
touch_dir(savings_dir)

"""
GETTING DATA FROM CSV FILES
"""
subj_data = []
mean_data = []
stdd_data = []
for idx, csv_path in enumerate(csv_paths):
    with open(csv_path) as f:
        # skip the header row and convert every cell to float
        table = [list(map(float, line)) for line in list(reader(f))[1:]]
    # one row per subject; the second-to-last column is taken as accuracy
    temp_data = [line[-2] for line in table]
    subj_data.append(temp_data)
    mean_data.append(np.mean(temp_data))
    stdd_data.append(np.std(temp_data))

"""
DEMIXING STUFF
"""
# map each freezing index to its statistics explicitly (0 = no frozen
# layers) instead of relying on slice positions within frozen_layers_list
mean_by_layer = dict(zip([0] + frozen_layers_list, mean_data))
stdd_by_layer = dict(zip([0] + frozen_layers_list, stdd_data))
pos_layer_keys = [0, 1, 2, 3, 4, 5, 6]
# the '6' run is also plotted as x = -6 on the negative curve (presumably
# freezing all six DeepConvNet layers is the same from either end - confirm)
neg_layer_keys = [0, -1, -2, -3, -4, -5, 6]
pos_layers_m = np.round(
    np.array([mean_by_layer[k] for k in pos_layer_keys]) * 100, 1)
neg_layers_m = np.round(
    np.array([mean_by_layer[k] for k in neg_layer_keys]) * 100, 1)
pos_layers_s = np.round(
    np.array([stdd_by_layer[k] for k in pos_layer_keys]) * 100, 1)
neg_layers_s = np.round(
    np.array([stdd_by_layer[k] for k in neg_layer_keys]) * 100, 1)

"""
PLOTTING POS
"""
plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
# NOTE(review): matplotlib >= 3.6 renamed this style to
# 'seaborn-v0_8-whitegrid' (the old name is gone in 3.8)
plt.style.use('seaborn-whitegrid')
plt.errorbar(x=[0, 1, 2, 3, 4, 5, 6],
             y=pos_layers_m,
             yerr=pos_layers_s,
             fmt='-.o', color='b', ecolor='r',
             linewidth=2, elinewidth=3, capsize=20, capthick=2)
plt.xlabel('indice di congelamento',
           fontsize=fontsize_1)
plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_2)
plt.yticks(fontsize=fontsize_2)
plt.title('esempi di training: 128', fontsize=fontsize_1)
savefig(
    join(savings_dir, 'frozen_layers_learning_curve_1'), bbox_inches='tight')

"""
PLOTTING NEG
"""
plt.figure(dpi=100, figsize=fig_size, facecolor='w', edgecolor='k')
plt.style.use('seaborn-whitegrid')
plt.errorbar(x=[0, -1, -2, -3, -4, -5, -6],
             y=neg_layers_m,
             yerr=neg_layers_s,
             fmt='-.o', color='b', ecolor='r',
             linewidth=2, elinewidth=3, capsize=20, capthick=2)
plt.xlabel('indice di congelamento',
           fontsize=fontsize_1)
plt.ylabel('accuratezza (%)', fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_2)
plt.yticks(fontsize=fontsize_2)
plt.title('esempi di training: 128', fontsize=fontsize_1)
savefig(
    join(savings_dir, 'frozen_layers_learning_curve_2'), bbox_inches='tight')

# ------------------------------------------------------------------------------
# /transfer_learning.py
# ------------------------------------------------------------------------------
from os import getcwd
from numpy import ceil
from os.path import join
from os.path import dirname
from collections import OrderedDict
from numpy.random import RandomState
from hgdecode.utils import get_path
from hgdecode.utils import create_log
from hgdecode.utils import print_manager
from hgdecode.utils import clear_all_models
from hgdecode.loaders import dl_loader
from hgdecode.classes import CrossValidation
from hgdecode.experiments import DLExperiment
from keras import backend as K

"""
SETTING PARAMETERS
------------------
"""

# setting model_name
model_name = 'DeepConvNet'  # Schirrmeister: 'DeepConvNet' or 'ShallowNet'

# setting channel_names
channel_names = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4',
                 'CP5', 'CP1', 'CP2', 'CP6',
                 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
                 'CP3', 'CPz', 'CP4',
                 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h',
                 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
                 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
                 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
                 'CCP1h', 'CCP2h', 'CPP1h', 'CPP2h']

# setting data_dir & results_dir (two levels above the working directory)
data_dir = join(dirname(dirname(getcwd())), 'datasets', 'High-Gamma')
results_dir = join(dirname(dirname(getcwd())), 'results', 'hgdecode')

# setting name_to_start_codes
name_to_start_codes = OrderedDict([('Right Hand', [1]),
                                   ('Left Hand', [2]),
                                   ('Rest', [3]),
                                   ('Feet', [4])])

# setting subject_ids
subject_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)

# setting random_state (shared by all subjects and folds)
random_state = RandomState(1234)

# setting fold_size: this will be the number of trials for training,
# so it must be a multiple of 4
fold_size = 8  # must be integer
validation_frac = 0.1

# setting frozen_layers
layers_to_freeze = 0  # can be between -6 and 6 for DeepConvNet

# other hyper-parameters
dropout_rate = 0.6
learning_rate = 2 * 1e-5
epochs = 100
ival = (-1000, 1000)

"""
GETTING CROSS-SUBJECT MODELS DIR PATH
-------------------------------------
"""
# setting cross_subj_dir_path: data from cross-subj computation are stored here
learning_type = 'dl'
algorithm_or_model_name = None
epoching = ival
fold_type = 'cross_subject'
n_folds = None
deprecated = False
cross_subj_dir_path = get_path(
    results_dir=dirname(results_dir),
    learning_type=learning_type,
    algorithm_or_model_name=algorithm_or_model_name,
    epoching=epoching,
    fold_type=fold_type,
    n_folds=n_folds,
    deprecated=deprecated
)

"""
COMPUTATION
-----------
"""
# loop-invariant settings, computed once:
# if fold_size is not a multiple of 4, round it UP to the next multiple
# (presumably so the 4 classes can be equally represented - confirm)
fold_size = int(ceil(fold_size / 4) * 4)
# the whole training set fits in one batch, capped at 64 trials
batch_size = min(fold_size, 64)

for subject_id in subject_ids:
    # creating a log object
    subj_results_dir = create_log(
        results_dir=results_dir,
        learning_type='dl',
        algorithm_or_model_name=model_name,
        subject_id=subject_id,
        output_on_file=False,
        use_last_result_directory=False
    )

    # loading epoched signal
    epo = dl_loader(
        data_dir=data_dir,
        name_to_start_codes=name_to_start_codes,
        channel_names=channel_names,
        subject_id=subject_id,
        resampling_freq=250,  # Schirrmeister: 250
        clean_ival_ms=ival,  # Schirrmeister: (0, 4000)
        epoch_ival_ms=ival,  # Schirrmeister: (-500, 4000)
        train_test_split=True,  # Schirrmeister: True
        clean_on_all_channels=False  # Schirrmeister: True
    )

    # creating CrossValidation class instance; the validation set is a
    # fraction of the data (validation_frac) rather than a fixed size equal
    # to fold_size, which the author judged a bad idea
    cross_validation = CrossValidation(
        X=epo.X,
        y=epo.y,
        fold_size=fold_size,
        validation_frac=validation_frac,
        random_state=random_state, shuffle=True,
        swap_train_test=True,
    )
    cross_validation.balance_train_set(train_size=fold_size)

    # pre-allocating experiment (stays None if there are no folds)
    exp = None

    # cycling on folds for cross validation
    for fold_idx, current_fold in enumerate(cross_validation.folds):
        # clearing TF graph (https://github.com/keras-team/keras/issues/3579)
        print_manager('CLEARING KERAS BACKEND', print_style='double-dashed')
        K.clear_session()
        print_manager(print_style='last', bottom_return=1)

        # printing fold information
        print_manager(
            'SUBJECT {}, FOLD {}'.format(subject_id, fold_idx + 1),
            print_style='double-dashed'
        )
        cross_validation.print_fold_classes(fold_idx)
        print_manager(print_style='last', bottom_return=1)

        # creating EEGDataset for current fold
        dataset = cross_validation.create_dataset(fold=current_fold)

        # creating experiment instance
        exp = DLExperiment(
            # non-default inputs
            dataset=dataset,
            model_name=model_name,
            results_dir=results_dir,
            subj_results_dir=subj_results_dir,
            name_to_start_codes=name_to_start_codes,
            random_state=random_state,
            fold_idx=fold_idx,

            # hyperparameters
            dropout_rate=dropout_rate,  # Schirrmeister: 0.5
            learning_rate=learning_rate,  # Schirrmeister: ?
            batch_size=batch_size,  # Schirrmeister: 512
            epochs=epochs,  # Schirrmeister: ?
            early_stopping=False,  # Schirrmeister: ?
            monitor='val_acc',  # Schirrmeister: ?
            min_delta=0.0001,  # Schirrmeister: ?
            patience=5,  # Schirrmeister: ?
            loss='categorical_crossentropy',  # Schirrmeister: ad hoc
            optimizer='Adam',  # Schirrmeister: Adam
            shuffle=True,  # Schirrmeister: ?
            crop_sample_size=None,  # Schirrmeister: 1125
            crop_step=None,  # Schirrmeister: 1

            # other parameters
            subject_id=subject_id,
            data_generator=False,  # Schirrmeister: True
            save_model_at_each_epoch=False
        )

        # loading model weights from cross-subject pre-trained model
        exp.prepare_for_transfer_learning(
            cross_subj_dir_path=cross_subj_dir_path,
            subject_id=subject_id,
            train_anyway=True
        )

        # freezing layers
        exp.freeze_layers(layers_to_freeze=layers_to_freeze)

        # training
        exp.train()

    # computing cross-validation
    if exp is not None:
        cross_validation.cross_validate(
            subj_results_dir=exp.subj_results_dir,
            label_names=name_to_start_codes)

    # clearing all models (they are not useful once we have results)
    clear_all_models(subj_results_dir)

# ------------------------------------------------------------------------------