├── Data
│   ├── .gitignore
│   └── preprocess.py
├── Models
│   └── .gitignore
├── ChannelSelection
│   ├── loader.py
│   ├── models.py
│   └── selectNchannels.py
└── README.md

/Data/.gitignore:
--------------------------------------------------------------------------------
# Ignore everything in this directory
*
# Except this file
!.gitignore
--------------------------------------------------------------------------------
/Models/.gitignore:
--------------------------------------------------------------------------------
# Ignore everything in this directory
*
# Except this file
!.gitignore
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# EEG Channel Selection with Gumbel-softmax

## About

This Python project is the PyTorch implementation of a concrete EEG channel selection layer based on the Gumbel-softmax method. The layer can be placed in front of any deep neural network architecture to jointly learn the optimal subset of EEG channels for the given task and the network weights. It is composed of selection neurons, each of which uses a continuous relaxation of a discrete distribution across the input channels to learn an optimal one-hot weight vector, selecting a single input channel instead of linearly combining them.

## Usage

This implementation operates on the dataset described in [1]. To download the data, follow the instructions at https://github.com/robintibor/high-gamma-dataset and place it in the Data folder. Then convert the raw HDF5 files into preprocessed npy files by installing BrainDecode 0.4.85 as described at https://robintibor.github.io/braindecode/ and running preprocess.py.

To run the code, install PyTorch (https://pytorch.org/) and run selectNchannels.py.

## References

[1] R. T. Schirrmeister, J. T. Springenberg, L. D. J. Fiederer, M. Glasstetter, K. Eggensperger, M. Tangermann, F. Hutter, W. Burgard, and T. Ball, "Deep learning with convolutional neural networks for EEG decoding and visualization," Human Brain Mapping, vol. 38, no. 11, pp. 5391–5420, 2017.
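The selection layer described in the README is exposed through the `SelectionNet` wrapper in ChannelSelection/models.py. A minimal sketch of the intended usage (hypothetical, not a repo script; assumes it is run from inside ChannelSelection/ so that `models` is importable, with shapes matching the high-gamma setup used in selectNchannels.py):

```python
# Minimal sketch: the Gumbel-softmax selection layer in front of a backbone.
# 44 channels x 1125 samples mirrors input_dim in selectNchannels.py; M is
# the number of channels to select. Hypothetical usage, not part of the repo.
import torch
from models import SelectionNet

model = SelectionNet(input_dim=[44, 1125], M=8)   # select 8 of 44 channels
x = torch.randn(16, 1, 44, 1125)                  # (batch, 1, channels, time)
if torch.cuda.is_available():                     # mirror the script: the layer
    model, x = model.cuda(), x.cuda()             # expects CUDA tensors if available
logits = model(x)                                 # (batch, 4) class scores
print(logits.shape)
```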
--------------------------------------------------------------------------------
/ChannelSelection/loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os


def all_subject_loader_HGD(batch_size, train_split, path):

    num_subjects = 14

    # Create datasets
    tr_ds = []
    val_ds = []
    test_ds = []

    for k in range(num_subjects):
        # Load training data and split it into training and validation sets
        traindatapath = os.path.join(path, str(k + 1) + "traindata.npy")
        trainlabelpath = os.path.join(path, str(k + 1) + "trainlabel.npy")
        train_eeg_data = torch.Tensor(np.load(traindatapath))
        train_labels = torch.LongTensor(np.load(trainlabelpath))

        split = round(train_split * train_eeg_data.size(0))

        for i in range(train_eeg_data.size(0)):
            x = train_eeg_data[i, :, :]
            x = x.view(1, x.size(0), x.size(1))
            y = train_labels[i]
            if i <= split:
                tr_ds.append([x, y])
            else:
                val_ds.append([x, y])

        # Load test data
        testdatapath = os.path.join(path, str(k + 1) + "testdata.npy")
        testlabelpath = os.path.join(path, str(k + 1) + "testlabel.npy")
        test_eeg_data = torch.Tensor(np.load(testdatapath))
        test_labels = torch.LongTensor(np.load(testlabelpath))

        for i in range(test_eeg_data.size(0)):
            x = test_eeg_data[i, :, :]
            x = x.view(1, x.size(0), x.size(1))
            y = test_labels[i]
            test_ds.append([x, y])

    trainloader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size,
                                              shuffle=True)
    valloader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size,
                                            shuffle=False)
    testloader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size,
                                             shuffle=False)
    return trainloader, valloader, testloader


def within_subject_loader_HGD(subject, batch_size, train_split, path):

    # Load training data and split it into training and validation sets
    traindatapath = os.path.join(path, str(subject) + "traindata.npy")
    trainlabelpath = os.path.join(path, str(subject) + "trainlabel.npy")
    train_eeg_data = torch.Tensor(np.load(traindatapath))
    train_labels = torch.LongTensor(np.load(trainlabelpath))

    tr_ds = []
    val_ds = []
    split = round(train_split * train_eeg_data.size(0))
    for i in range(train_eeg_data.size(0)):
        x = train_eeg_data[i, :, :]
        # x = x[::2, :]
        x = x.view(1, x.size(0), x.size(1))
        y = train_labels[i]
        if i <= split:
            tr_ds.append([x, y])
        else:
            val_ds.append([x, y])

    # Load test data
    testdatapath = os.path.join(path, str(subject) + "testdata.npy")
    testlabelpath = os.path.join(path, str(subject) + "testlabel.npy")
    test_eeg_data = torch.Tensor(np.load(testdatapath))
    test_labels = torch.LongTensor(np.load(testlabelpath))

    test_ds = []
    for i in range(test_eeg_data.size(0)):
        x = test_eeg_data[i, :, :]
        # x = x[::2, :]
        x = x.view(1, x.size(0), x.size(1))
        y = test_labels[i]
        test_ds.append([x, y])

    trainloader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size,
                                              shuffle=False)
    valloader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size,
                                            shuffle=False)
    testloader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size,
                                             shuffle=False)
    return trainloader, valloader, testloader
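For reference, a short sketch of how these loaders are invoked, mirroring the calls in selectNchannels.py (hypothetical snippet; assumes preprocess.py has already written the .npy files and the working directory is ChannelSelection/, so the data sits at ../Data/):

```python
# Sketch of calling the loaders above, as selectNchannels.py does.
# Assumes files such as ../Data/1traindata.npy exist after preprocessing.
from loader import all_subject_loader_HGD, within_subject_loader_HGD

train_loader, val_loader, test_loader = all_subject_loader_HGD(
    batch_size=16, train_split=0.8, path="../Data/")

# Per-subject variant used in the fine-tuning stage (subjects are 1-based):
tr, val, te = within_subject_loader_HGD(
    subject=1, batch_size=16, train_split=0.8, path="../Data/")
for x, y in tr:
    # Expected: torch.Size([16, 1, 44, 1125]) and torch.Size([16])
    print(x.shape, y.shape)
    break
```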
--------------------------------------------------------------------------------
/Data/preprocess.py:
--------------------------------------------------------------------------------
# Code adapted from https://github.com/robintibor/high-gamma-dataset
import logging
import sys
import os.path
from collections import OrderedDict
import numpy as np

from braindecode.datasets.bbci import BBCIDataset
from braindecode.datautil.signalproc import highpass_cnt
import torch.nn.functional as F
import torch as th
from torch import optim
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.experiments.experiment import Experiment
from braindecode.torch_ext.util import np_to_var
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
    RuntimeMonitor, CroppedTrialMisclassMonitor

from braindecode.datautil.splitters import split_into_two_sets
from braindecode.datautil.trial_segment import \
    create_signal_target_from_raw_mne
from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
from braindecode.datautil.signalproc import exponential_running_standardize

from braindecode.datasets.sensor_positions import (
    CHANNEL_10_20_APPROX,
    get_channelpos,
)

import scipy.io as sio
import mne


log = logging.getLogger(__name__)
log.setLevel('DEBUG')


def load_bbci_data(filename, low_cut_hz):
    load_sensor_names = None
    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)

    log.info("Loading data...")
    cnt = loader.load()

    # Cleaning: first find all trials that contain absolute microvolt values
    # larger than +-800 and remember them for removal later
    log.info("Cutting trials...")

    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2]),
                              ('Rest', [3]), ('Feet', [4])])
    clean_ival = [0, 4000]

    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
                                                         clean_ival)

    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800

    log.info("Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
        np.sum(clean_trial_mask),
        len(set_for_cleaning.X),
        np.mean(clean_trial_mask) * 100))

    # Now pick only sensors with C in their name, as they cover motor cortex
    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
                 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
                 'C6',
                 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h',
                 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
                 'CPP5h',
                 'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
                 'CCP1h',
                 'CCP2h', 'CPP1h', 'CPP2h']

    cnt = cnt.pick_channels(C_sensors)

    # Further preprocessing
    log.info("Resampling...")
    cnt = resample_cnt(cnt, 250.0)

    log.info("Highpassing...")
    cnt = mne_apply(
        lambda a: highpass_cnt(a, low_cut_hz, cnt.info['sfreq'],
                               filt_order=3, axis=1), cnt)
    log.info("Standardizing...")
    cnt = mne_apply(
        lambda a: exponential_running_standardize(
            a.T, factor_new=1e-3, init_block_size=1000, eps=1e-4).T, cnt)

    # Trial interval: start at -500 ms already, since this improves decoding
    # for the networks
    ival = [-500, 4000]

    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)

    dataset.X = dataset.X[clean_trial_mask]
    dataset.y = dataset.y[clean_trial_mask]
    return dataset.X, dataset.y


low_cut_hz = 4
for i in range(14):
    print("Start Train Data Subject " + str(i + 1))
    filename = "train/" + str(i + 1) + ".mat"
    savenamedata = "train/" + str(i + 1) + "traindata.npy"
    savenamelabel = "train/" + str(i + 1) + "trainlabel.npy"
    X, y = load_bbci_data(filename, low_cut_hz)
    np.save(savenamedata, X)
    np.save(savenamelabel, y)

for i in range(14):
    print("Start Test Data Subject " + str(i + 1))
    filename = "test/" + str(i + 1) + ".mat"
    savenamedata = "test/" + str(i + 1) + "testdata.npy"
    savenamelabel = "test/" + str(i + 1) + "testlabel.npy"
    X, y = load_bbci_data(filename, low_cut_hz)
    np.save(savenamedata, X)
    np.save(savenamelabel, y)
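A quick sanity check on the arrays this script writes (hypothetical snippet, run from the Data folder after preprocessing). Note that as written, the loaders in ChannelSelection/loader.py look for these files directly under Data/, while this script writes them into train/ and test/ subfolders, so the files may need to be moved (or the paths aligned) before training:

```python
# Hypothetical shape check on preprocess.py output. Each trial spans
# ival = [-500, 4000] ms at 250 Hz, i.e. 4.5 s * 250 Hz = 1125 samples,
# over the 44 selected motor-cortex sensors.
import numpy as np

X = np.load("train/1traindata.npy")
y = np.load("train/1trainlabel.npy")
print(X.shape)  # expected: (number of clean trials, 44, 1125)
print(y.shape)  # expected: (number of clean trials,)
```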
--------------------------------------------------------------------------------
/ChannelSelection/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import math
import scipy.io as sio
from torch.nn.parameter import Parameter


## PRUNED TRAINING

limit_a, limit_b, epsilon = -.1, 1.1, 1e-6

## SELECT N CHANNELS


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)


class MSFBCNN(nn.Module):
    def __init__(self, input_dim, output_dim, FT=10):
        super(MSFBCNN, self).__init__()
        self.T = input_dim[1]
        self.FT = FT
        self.D = 1
        self.FS = self.FT * self.D
        self.C = input_dim[0]
        self.output_dim = output_dim

        # Parallel temporal convolutions
        self.conv1a = nn.Conv2d(1, self.FT, (1, 65), padding=(0, 32), bias=False)
        self.conv1b = nn.Conv2d(1, self.FT, (1, 41), padding=(0, 20), bias=False)
        self.conv1c = nn.Conv2d(1, self.FT, (1, 27), padding=(0, 13), bias=False)
        self.conv1d = nn.Conv2d(1, self.FT, (1, 17), padding=(0, 8), bias=False)

        self.batchnorm1 = nn.BatchNorm2d(4 * self.FT, False)

        # Spatial convolution
        self.conv2 = nn.Conv2d(4 * self.FT, self.FS, (self.C, 1), padding=(0, 0), groups=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(self.FS, False)

        # Temporal average pooling
        self.pooling2 = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=(0, 0))

        self.drop = nn.Dropout(0.5)

        # Classification
        self.fc1 = nn.Linear(self.FS * math.ceil(1 + (self.T - 75) / 15), self.output_dim)

    def forward(self, x):

        # Layer 1
        x1 = self.conv1a(x)
        x2 = self.conv1b(x)
        x3 = self.conv1c(x)
        x4 = self.conv1d(x)

        x = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.batchnorm1(x)

        # Layer 2
        x = torch.pow(self.batchnorm2(self.conv2(x)), 2)
        x = self.pooling2(x)
        x = torch.log(x)
        x = self.drop(x)

        # FC layer
        x = x.view(-1, self.num_flat_features(x))
        x = self.fc1(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


class SelectionLayer(nn.Module):
    def __init__(self, N, M, temperature=1.0):

        super(SelectionLayer, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.N = N
        self.M = M
        self.qz_loga = Parameter(torch.randn(N, M) / 100)

        self.temperature = self.floatTensor([temperature])
        self.freeze = False
        self.thresh = 3.0

    def quantile_concrete(self, x):
        # Gumbel-softmax: g = -log(-log(u)) with u ~ Uniform(0, 1); the
        # softmax over the channel dimension yields a relaxed one-hot sample
        g = -torch.log(-torch.log(x))
        y = (self.qz_loga + g) / self.temperature
        y = torch.softmax(y, dim=1)
        return y

    def regularization(self):
        # Penalize channels whose total selection probability across the M
        # neurons exceeds thresh, discouraging duplicate selections
        eps = 1e-10
        z = torch.clamp(torch.softmax(self.qz_loga, dim=0), eps, 1)
        H = torch.sum(F.relu(torch.norm(z, 1, dim=1) - self.thresh))
        return H

    def get_eps(self, size):
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        return eps

    def sample_z(self, batch_size, training):
        if training:
            # Sample a relaxed one-hot selection matrix per example
            eps = self.get_eps(self.floatTensor(batch_size, self.N, self.M))
            z = self.quantile_concrete(eps)
            z = z.view(z.size(0), 1, z.size(1), z.size(2))
            return z
        else:
            # At test time, hard-select the most probable channel per neuron
            ind = torch.argmax(self.qz_loga, dim=0)
            one_hot = self.floatTensor(np.zeros((self.N, self.M)))
            for j in range(self.M):
                one_hot[ind[j], j] = 1
            one_hot = one_hot.view(1, 1, one_hot.size(0), one_hot.size(1))
            one_hot = one_hot.expand(batch_size, 1, one_hot.size(2), one_hot.size(3))
            return one_hot

    def forward(self, x):
        z = self.sample_z(x.size(0), training=(self.training and not self.freeze))
        z_t = torch.transpose(z, 2, 3)
        out = torch.matmul(z_t, x)
        return out


class SelectionNet(nn.Module):

    def __init__(self, input_dim, M, output_dim=4):
        super(SelectionNet, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor

        self.N = input_dim[0]
        self.T = input_dim[1]
        self.M = M
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.network = MSFBCNN(input_dim=[self.M, self.T], output_dim=output_dim)

        self.selection_layer = SelectionLayer(self.N, self.M)

        self.layers = self.create_layers_field()
        self.apply(init_weights)

    def forward(self, x):
        y_selected = self.selection_layer(x)
        out = self.network(y_selected)
        return out

    def regularizer(self, lamba, weight_decay):
        # Regularization of the selection layer
        reg_selection = self.floatTensor([0])
        # L2-regularization of the other layers
        reg = self.floatTensor([0])
        for i, layer in enumerate(self.layers):
            if type(layer) == SelectionLayer:
                reg_selection += layer.regularization()
            else:
                reg += torch.sum(torch.pow(layer.weight, 2))
        reg = weight_decay * reg + lamba * reg_selection
        return reg

    def create_layers_field(self):
        layers = []
        for idx, m in enumerate(self.modules()):
            if type(m) == nn.Conv2d or type(m) == nn.Linear or type(m) == SelectionLayer:
                layers.append(m)
        return layers

    def get_num_params(self):
        t = 0
        for i, layer in enumerate(self.layers):
            print('Layer ' + str(i))
            print(layer)
            n = 0
            for p in layer.parameters():
                n += np.prod(np.array(p.size()))
            print('Amount of parameters: ' + str(n))
            t += n
        print('Total amount of parameters: ' + str(t))
        return t

    def set_temperature(self, temp):
        m = self.selection_layer
        m.temperature = temp

    def set_thresh(self, thresh):
        m = self.selection_layer
        m.thresh = thresh

    def monitor(self):
        m = self.selection_layer
        eps = 1e-10
        # Probability distribution over channels, per selection neuron
        z = torch.clamp(torch.softmax(m.qz_loga, dim=0), eps, 1)
        # Normalized entropy of each distribution
        H = -torch.sum(z * torch.log(z), dim=0) / math.log(self.N)
        # Current selections (1-based channel indices)
        s = torch.argmax(m.qz_loga, dim=0) + 1
        return H, s, z

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def set_freeze(self, x):
        m = self.selection_layer
        if x:
            for param in m.parameters():
                param.requires_grad = False
            m.freeze = True
        else:
            for param in m.parameters():
                param.requires_grad = True
            m.freeze = False
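To make the mechanics of SelectionLayer concrete, here is a small standalone demonstration of the continuous relaxation it uses (a sketch independent of the repo, for intuition only):

```python
# Standalone sketch of the Gumbel-softmax relaxation used by SelectionLayer.
# A selection neuron holds logits over N channels; adding Gumbel noise and
# applying a temperature-controlled softmax yields a differentiable,
# approximately one-hot sample. Low temperature -> nearly one-hot;
# high temperature -> nearly uniform.
import torch

N = 5
logits = torch.randn(N)
u = torch.rand(N).clamp(1e-6, 1 - 1e-6)   # uniform noise, as in get_eps
g = -torch.log(-torch.log(u))             # Gumbel(0, 1) noise

for tau in [10.0, 1.0, 0.1]:
    y = torch.softmax((logits + g) / tau, dim=0)
    print(f"tau={tau}: {y}")
```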
--------------------------------------------------------------------------------
/ChannelSelection/selectNchannels.py:
--------------------------------------------------------------------------------
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import models
from loader import within_subject_loader_HGD, all_subject_loader_HGD
from models import SelectionNet, init_weights

import statistics
from random import randint
import importlib


parser = argparse.ArgumentParser(description='PyTorch Channel Selection Training')

parser.add_argument('--M', type=int, default=3,
                    help='number of selection neurons')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', type=int, default=16,
                    help='mini-batch size')
parser.add_argument('--gradacc', type=int, default=1,
                    help='gradient accumulation')
parser.add_argument('--weight-decay', '--wd', type=float, default=5e-4,
                    help='weight decay')
parser.add_argument('--lr', '--learning-rate', type=float, default=0.001,
                    help='initial learning rate')
parser.add_argument('--lamba', type=float, default=0.1,
                    help='regularization weight')
parser.add_argument('--start_temp', type=float, default=10.0,
                    help='initial temperature')
parser.add_argument('--end_temp', type=float, default=0.1,
                    help='final temperature')
parser.add_argument('--train_split', type=float, default=0.8,
                    help='training-validation data split')
parser.add_argument('--patience', type=int, default=10,
                    help='number of epochs without improvement before early stopping')
parser.add_argument('--stop_delta', type=float, default=1e-3,
                    help='minimal improvement in validation loss required to reset early stopping')
parser.add_argument('--entropy_lim', type=float, default=0.05,
                    help='mean normalized entropy the selection neurons must reach for convergence')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed, 0 indicates a randomly chosen seed')
parser.add_argument('-v', action="store_true", default=True, dest="verbose")
def main():

    global args, enable_cuda

    ############################################## INIT ##############################################

    args = parser.parse_args()

    cwd = os.getcwd()
    dpath = os.path.dirname(cwd)
    # Paths for data, model and checkpoint
    data_path = os.path.join(dpath, 'Data/')
    model_save_path = os.path.join(dpath, 'Models', 'Model_GumbelregHighgamma_M' + str(args.M))
    checkpoint_path = os.path.join(dpath, 'Models', 'Checkpoint_GumbelregHighgamma_M' + str(args.M))
    if not os.path.isdir(os.path.join(dpath, 'Models')):
        os.makedirs(os.path.join(dpath, 'Models'))

    # Check if CUDA is available
    enable_cuda = torch.cuda.is_available()
    if args.verbose:
        print('GPU computing: ', enable_cuda)

    # Set random seed
    if args.seed == 0:
        args.seed = randint(1, 99999)

    # Initialize devices with the random seed
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    training_accs = []
    val_accs = []
    test_accs = []

    # Create a vector of length `epochs` that decays start_value to end_value
    # exponentially, reaching end_value at end_epoch
    def exponential_decay_schedule(start_value, end_value, epochs, end_epoch):
        t = torch.FloatTensor(torch.arange(0.0, epochs))
        p = torch.clamp(t / end_epoch, 0, 1)
        out = start_value * torch.pow(end_value / start_value, p)
        return out

    # Network loss function
    def loss_function(output, target, model, lamba, weight_decay):
        l = nn.CrossEntropyLoss()
        sup_loss = l(output, target)
        reg = model.regularizer(lamba, weight_decay)
        return sup_loss, reg

    # Create schedules for the temperature and the regularization threshold
    temperature_schedule = exponential_decay_schedule(args.start_temp, args.end_temp, args.epochs, int(args.epochs * 3 / 4))
    thresh_schedule = exponential_decay_schedule(3.0, 1.1, args.epochs, args.epochs)

    # Load data
    num_subjects = 14
    input_dim = [44, 1125]
    train_loader, val_loader, test_loader = all_subject_loader_HGD(batch_size=args.batch_size, train_split=args.train_split, path=data_path)

    ############################ SUBJECT-INDEPENDENT CHANNEL SELECTION ###############################

    if args.verbose:
        print('Start training')

    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Instantiate the model
    model = SelectionNet(input_dim, args.M)
    if enable_cuda:
        model.cuda()
    model.set_freeze(False)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    prev_val_loss = 100
    patience_timer = 0
    early_stop = False
    epoch = 0

    while epoch < args.epochs and not early_stop:

        # Update temperature and threshold
        model.set_thresh(thresh_schedule[epoch])
        model.set_temperature(temperature_schedule[epoch])

        # Perform a training step
        train(train_loader, model, loss_function, optimizer, epoch, args.weight_decay, args.lamba, args.gradacc, args.verbose)
        val_loss = validate(val_loader, model, loss_function, epoch, args.weight_decay, args.lamba, args.verbose)
        tr_acc, val_acc, test_acc = test(train_loader, val_loader, test_loader, model, loss_function, args.weight_decay, args.verbose)

        # Extract selection neuron entropies, current selections and probability distributions
        H, sel, probas = model.monitor()

        # Once selection convergence is reached, enable the early stopping scheme
        if (torch.mean(H.data) <= args.entropy_lim) and (val_loss > prev_val_loss - args.stop_delta):
            patience_timer += 1
            if args.verbose:
                print('Early stopping timer ', patience_timer)
            if patience_timer == args.patience:
                early_stop = True
        else:
            patience_timer = 0
            H, sel, probas = model.monitor()
            torch.save(model.state_dict(), checkpoint_path)
            prev_val_loss = val_loss

        epoch += 1

    if args.verbose:
        print('Channel selection finished')

    # Store the subject-independent model
    model.load_state_dict(torch.load(checkpoint_path))
    pretrained_path = str(model_save_path + 'all_subjects_channels_selected.pt')
    torch.save(model.state_dict(), pretrained_path)

    ###################################### SUBJECT FINETUNING ########################################

    if args.verbose:
        print('Start subject specific training')

    for k in range(1, num_subjects + 1):

        if args.verbose:
            print('Start training for subject ' + str(k))

        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Load the subject-independent model and freeze the selection neurons
        model = SelectionNet(input_dim, args.M)
        model.load_state_dict(torch.load(pretrained_path))
        if enable_cuda:
            model.cuda()
        model.set_freeze(True)

        # Load subject-dependent data
        train_loader, val_loader, test_loader = within_subject_loader_HGD(subject=k, batch_size=args.batch_size, train_split=args.train_split, path=data_path)

        optimizer = torch.optim.Adam(model.parameters(), args.lr)

        prev_val_loss = 100
        patience_timer = 0
        early_stop = False
        epoch = 0
        while epoch < args.epochs and not early_stop:

            # Perform a training step
            train(train_loader, model, loss_function, optimizer, epoch, args.weight_decay, args.lamba, args.gradacc, args.verbose)
            val_loss = validate(val_loader, model, loss_function, epoch, args.weight_decay, args.lamba, args.verbose)
            tr_acc, val_acc, test_acc = test(train_loader, val_loader, test_loader, model, loss_function, args.weight_decay, args.verbose)

            # Extract selection neuron entropies, current selections and probability distributions
            H, sel, probas = model.monitor()

            # Perform early stopping
            if val_loss > prev_val_loss - args.stop_delta:
                patience_timer += 1
                if args.verbose:
                    print('Early stopping timer ', patience_timer)
                if patience_timer == args.patience:
                    early_stop = True
            else:
                patience_timer = 0
                torch.save(model.state_dict(), checkpoint_path)
                prev_val_loss = val_loss

            epoch += 1

        # Store the model with the lowest validation loss
        model.load_state_dict(torch.load(checkpoint_path))
        path = str(model_save_path + 'finished_subject' + str(k) + '.pt')
        torch.save(model.state_dict(), path)

        # Evaluate the model
        tr_acc, val_acc, test_acc = test(train_loader, val_loader, test_loader, model, loss_function, args.weight_decay, args.verbose)
        training_accs.append(tr_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    ########################################## TERMINATION ###########################################

    print('Selection', sel.data)
    print('Training accuracies', training_accs)
    print('Validation accuracies', val_accs)
    print('Testing accuracies', test_accs)

    tr_med = statistics.median(training_accs)
    val_med = statistics.median(val_accs)
    test_med = statistics.median(test_accs)
    tr_mean = statistics.mean(training_accs)
    val_mean = statistics.mean(val_accs)
    test_mean = statistics.mean(test_accs)

    print('Training median accuracy', tr_med)
    print('Validation median accuracy', val_med)
    print('Testing median accuracy', test_med)
    print('Training mean accuracy', tr_mean)
    print('Validation mean accuracy', val_mean)
    print('Testing mean accuracy', test_mean)


# Train for one epoch
def train(train_loader, model, loss_function, optimizer, epoch, weight_decay, lamba, gradacc, verbose):

    global running_loss, running_sup_loss, running_reg, running_acc, enable_cuda

    model.train()

    for i, (data, labels) in enumerate(train_loader):

        if enable_cuda:
            data = data.cuda()
            labels = labels.cuda()

        if i == 0:
            running_loss = 0.0
            running_reg = 0.0
            running_sup_loss = 0.0
            running_acc = np.array([0, 0])

        output = model(data)

        sup_loss, reg = loss_function(output, labels, model, lamba, weight_decay)
        loss = sup_loss + reg
        loss = loss / gradacc

        loss.backward()

        # Perform gradient accumulation
        if (i + 1) % gradacc == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Running accuracy
        score, predicted = torch.max(output, 1)
        total = predicted.size(0)
        correct = (predicted == labels).sum().item()
        running_acc = np.add(running_acc, np.array([correct, total]))

        # Print statistics
        running_loss += loss.item()
        running_reg += reg.item()
        running_sup_loss += sup_loss.item()
        N = len(train_loader)
        if i == N - 1:
            if verbose:
                print('[%d, %5d] loss: %.3f acc: %d %% supervised loss: %.3f regularization loss %.3f' %
                      (epoch + 1, i + 1, running_loss / N, 100 * running_acc[0] / running_acc[1], running_sup_loss / N, running_reg / N))
            running_loss = 0.0
            running_reg = 0.0
            running_sup_loss = 0.0
            running_acc = np.array([0, 0])


def validate(val_loader, model, loss_function, epoch, weight_decay, lamba, verbose):

    global val_acc, val_loss, enable_cuda

    with torch.no_grad():
        model.eval()

        for i, (data, labels) in enumerate(val_loader):

            if enable_cuda:
                data = data.cuda()
                labels = labels.cuda()

            if i == 0:
                val_loss = 0.0
                val_acc = np.array([0, 0])

            output = model(data)
            sup_loss, reg = loss_function(output, labels, model, lamba, weight_decay)
            loss = sup_loss

            # Running accuracy
            score, predicted = torch.max(output, 1)
            total = predicted.size(0)
            correct = (predicted == labels).sum().item()
            val_acc = np.add(val_acc, np.array([correct, total]))

            # Print statistics
            val_loss += loss.item()
            N = len(val_loader)
            if i == N - 1:
                if verbose:
                    print('[%d, %5d] Validation loss: %.3f Validation accuracy: %d %%' %
                          (epoch + 1, i + 1, val_loss / N, 100 * val_acc[0] / val_acc[1]))

    return val_loss / N


def test(train_loader, val_loader, test_loader, model, loss_function, weight_decay, verbose):

    global enable_cuda

    with torch.no_grad():

        # Training accuracy is measured in train mode, i.e. with the
        # selection layer still sampling relaxed selections
        model.train()

        total = 0
        correct = 0

        for i, (data, labels) in enumerate(train_loader):

            if enable_cuda:
                data = data.cuda()
                labels = labels.cuda()

            output = model(data)
            score, predicted = torch.max(output, 1)
            total += predicted.size(0)
            correct += (predicted == labels).sum().item()

        tr_acc = correct / total

        if verbose:
            print('Training accuracy: %d %%' % (100 * tr_acc))

        # Validation and test accuracy are measured in eval mode, i.e. with
        # hard one-hot channel selections
        model.eval()

        total = 0
        correct = 0

        for i, (data, labels) in enumerate(val_loader):

            if enable_cuda:
                data = data.cuda()
                labels = labels.cuda()

            output = model(data)
            score, predicted = torch.max(output, 1)
            total += predicted.size(0)
            correct += (predicted == labels).sum().item()

        val_acc = correct / total

        if verbose:
            print('Validation accuracy: %d %%' % (100 * val_acc))

        total = 0
        correct = 0

        for i, (data, labels) in enumerate(test_loader):

            if enable_cuda:
                data = data.cuda()
                labels = labels.cuda()

            output = model(data)
            score, predicted = torch.max(output, 1)
            total += predicted.size(0)
            correct += (predicted == labels).sum().item()

        test_acc = correct / total

        if verbose:
            print('Test accuracy: %d %%' % (100 * test_acc))

    return tr_acc, val_acc, test_acc


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
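For intuition about the annealing used above, the decay schedules can be inspected in isolation (a standalone sketch reproducing exponential_decay_schedule with the script's default arguments):

```python
# Standalone sketch of the schedules in selectNchannels.py with its defaults:
# the temperature decays from 10.0 to 0.1, reaching 0.1 at 3/4 of the 200
# epochs and staying flat afterwards; the duplicate-selection threshold
# decays from 3.0 towards 1.1 over all 200 epochs.
import torch

def exponential_decay_schedule(start_value, end_value, epochs, end_epoch):
    t = torch.arange(0.0, epochs)
    p = torch.clamp(t / end_epoch, 0, 1)
    return start_value * torch.pow(end_value / start_value, p)

temperature = exponential_decay_schedule(10.0, 0.1, 200, 150)
thresh = exponential_decay_schedule(3.0, 1.1, 200, 200)

# Geometric decay: at half of end_epoch, 10.0 * (0.1/10.0)**0.5 = 1.0
print(temperature[0].item(), temperature[75].item(), temperature[150].item())
print(thresh[0].item(), thresh[-1].item())  # -> 3.0 and approximately 1.1
```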