├── .gitignore ├── EEG-Inception ├── EEG_Inception.py ├── Net_train.py ├── Readme.md └── cross_val_net.py ├── EEGNet ├── EEGNet_paper.pth ├── EEGNet_version_01.py ├── Readme.md ├── Test_model.py └── Test_net.py ├── ML-MI ├── BCICIV_calib_ds1b.mat ├── EEG_MI_ML.ipynb ├── Readme.md └── refresh_classification.ipynb ├── New train folder ├── EEGNet.py ├── Readme.md ├── confusion_matrix1.png ├── metrics.py ├── read_data.ipynb ├── train.py └── training_metrics_subject_1.png ├── README.md ├── mne包教程.pdf └── 脑机接口导论笔记.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /EEG-Inception/EEG_Inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchsummary import summary 5 | 6 | 7 | class DepthwiseSeparableConv2d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 9 | super(DepthwiseSeparableConv2d, self).__init__() 10 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding='valid', groups=in_channels) 11 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1) 12 | 13 | def forward(self, x): 14 | x = self.depthwise(x) 15 | x = self.pointwise(x) 16 | # 进行通道压缩 17 | 18 | 19 | return x 20 | 21 | class EEGInception(nn.Module): 22 | def __init__(self, input_time=1000, fs=128, ncha=8, filters_per_branch=8, 23 | scales_time=(500, 250, 125), dropout_rate=0.25, 24 | activation='relu', n_classes=2): 25 | super(EEGInception, self).__init__() 26 | 27 | # ============================= CALCULATIONS ============================= # 28 | input_samples = int(input_time * fs / 1000) 29 | scales_samples = [int(s * fs / 1000) for s in scales_time] 30 | 31 | # ================================ INPUT ================================= # 32 | self.input_layer = nn.Conv2d(1, ncha, kernel_size=(1, 1)) 33 | 34 | # ========================== BLOCK 1: INCEPTION ========================== # 35 | b1_units = [] 36 | for i in range(len(scales_samples)): 37 | unit = nn.Sequential( 38 | nn.Conv2d(ncha, ncha, kernel_size=(1, scales_samples[i]), padding='same'), 39 | nn.BatchNorm2d(ncha), 40 | nn.ELU(inplace=True), 41 | DepthwiseSeparableConv2d(ncha, ncha*2, kernel_size=(ncha, 1)), 42 | nn.BatchNorm2d(ncha*2), 43 | nn.ELU(inplace=True), 44 | nn.Dropout(dropout_rate) 45 | ) 46 | b1_units.append(unit) 47 | 48 | self.b1_units = nn.ModuleList(b1_units) 49 | 50 | # 
========================== BLOCK 2: INCEPTION ========================== # 51 | b2_units = [] 52 | for i in range(len(scales_samples)): 53 | unit = nn.Sequential( 54 | nn.Conv2d(filters_per_branch*6, filters_per_branch, kernel_size=(int(scales_samples[i]/4), 1), padding='same', padding_mode='zeros'), 55 | nn.BatchNorm2d(filters_per_branch), 56 | nn.ELU(inplace=True), 57 | nn.Dropout(dropout_rate) 58 | ) 59 | b2_units.append(unit) 60 | 61 | self.b2_units = nn.ModuleList(b2_units) 62 | 63 | # ============================ BLOCK 3: OUTPUT =========================== # 64 | self.b3_u1 = nn.Sequential( 65 | nn.Conv2d(filters_per_branch * len(scales_samples), int(filters_per_branch*len(scales_samples)/2), kernel_size=(8, 1),padding='same'), 66 | nn.BatchNorm2d(int(filters_per_branch*len(scales_samples)/2)), 67 | nn.ELU(inplace=True), 68 | nn.AvgPool2d((2, 1)), 69 | nn.Dropout(dropout_rate) 70 | ) 71 | 72 | self.b3_u2 = nn.Sequential( 73 | nn.Conv2d(int(filters_per_branch*len(scales_samples)/2), int(filters_per_branch*len(scales_samples)/4), kernel_size=(4, 1),padding='same'), 74 | nn.BatchNorm2d(int(filters_per_branch*len(scales_samples)/4)), 75 | nn.ELU(inplace=True), 76 | nn.AvgPool2d((2, 1)), 77 | nn.Dropout(dropout_rate) 78 | ) 79 | 80 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 81 | self.fc = nn.Linear(int(filters_per_branch*len(scales_samples)/4), n_classes) 82 | 83 | def forward(self, x): 84 | # ================================ INPUT ================================= # 85 | 86 | 87 | x = self.input_layer(x) 88 | 89 | 90 | 91 | # ========================== BLOCK 1: INCEPTION ========================== # 92 | b1_outputs = [unit(x) for unit in self.b1_units] 93 | 94 | 95 | b1_out = torch.cat(b1_outputs, dim=1) 96 | 97 | b1_out = b1_out.permute((0, 1, 3, 2)) 98 | 99 | b1_out = F.avg_pool2d(b1_out, (4, 1)) 100 | # b1_out = b1_out.permute((0, 2, 1, 3)) 101 | 102 | 103 | 104 | 105 | # ========================== BLOCK 2: INCEPTION ========================== # 106 | 
b2_outputs = [unit(b1_out) for unit in self.b2_units] 107 | 108 | b2_out = torch.cat(b2_outputs, dim=1) 109 | 110 | b2_out = F.avg_pool2d(b2_out, (2, 1)) 111 | 112 | 113 | # ============================ BLOCK 3: OUTPUT =========================== # 114 | b3_u1_out = F.avg_pool2d(F.elu(self.b3_u1(b2_out)), (2, 1)) 115 | 116 | b3_u2_out = F.avg_pool2d(F.elu(self.b3_u2(b3_u1_out)), (2, 1)) 117 | 118 | b3_out = self.avgpool(b3_u2_out) 119 | 120 | b3_out = b3_out.view(b3_out.size(0), -1) 121 | output = self.fc(b3_out) 122 | return output 123 | 124 | 125 | 126 | 127 | 128 | if __name__ == '__main__': 129 | data = torch.randn(1, 1, 8, 128).to('cuda') 130 | model = EEGInception().to('cuda') 131 | output = model(data) 132 | sum_parameter = 0 133 | for param in model.parameters(): 134 | sum_parameter += param.numel() 135 | print(sum_parameter) 136 | summary(model, (1, 8, 128), device='cuda', batch_size=48) 137 | -------------------------------------------------------------------------------- /EEG-Inception/Net_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from EEGNet_pytorch_version import * 5 | from torch import optim 6 | from EEG_Inception import * 7 | 8 | from sklearn.model_selection import cross_val_score, StratifiedKFold 9 | 10 | 11 | 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | torch.backends.cudnn.enabled = True 15 | # 加载EEGNetdata 16 | 17 | EEGNetdata = EEGNetDataset(file_path='C:\\Users\\24242\\PycharmProjects\\paper_rebuild\\EEG-dataprocessing\\EEG-Conformerprocessing\\combine_data_and_label\\A01_combine\\train_data_A01.pt', target_path='C:\\Users\\24242\\PycharmProjects\\paper_rebuild\\EEG-dataprocessing\\EEG-Conformerprocessing\\combine_data_and_label\\A01_combine\\train_label_A01.pt', transform=False, target_transform=False) 18 | 19 | train_dataloader = DataLoader(EEGNetdata, batch_size=48, 
shuffle=False) 20 | 21 | # 构建EEGNet 22 | print(device) 23 | 24 | net = EEGInception(input_time=1000, fs=250, ncha=22, filters_per_branch=22, n_classes=4).to(device) 25 | 26 | 27 | 28 | # 损失函数 29 | criterion = nn.CrossEntropyLoss() 30 | 31 | optimizer = optim.Adam(net.parameters(), lr=0.001) 32 | counter = [] 33 | # 画图要用 34 | loss_history = [] 35 | 36 | iteration_number = 0 37 | train_correct = 0 38 | total = 0 39 | 40 | classNum = 4 41 | # 画图要用 42 | acc_history = [] 43 | # 开启训练模式 44 | net.train() 45 | 46 | for epoch in range(0, 250): 47 | for i,data in enumerate(train_dataloader, 0): 48 | item, target = data 49 | item, target = item.to(device), target.to(device) 50 | item = item.type(torch.cuda.FloatTensor) 51 | target = target.type(torch.cuda.LongTensor) 52 | 53 | optimizer.zero_grad() 54 | 55 | output = net(item) 56 | 57 | loss = criterion(output, target) 58 | loss.backward() 59 | optimizer.step() 60 | 61 | pred = torch.max(output.data, 1)[1] 62 | train_correct += (pred == target).sum().item() 63 | total += target.size(0) 64 | train_acc = train_correct / total 65 | train_acc = np.array(train_acc) 66 | if i % 50 == 0: 67 | print('Epoch number {}\n acc {}\n loss {}'.format(epoch, train_acc, loss)) 68 | 69 | iteration_number += 1 70 | counter.append(iteration_number) 71 | acc_history.append(train_acc.item()) 72 | loss_history.append(loss.item()) 73 | 74 | 75 | show_plot(counter, acc_history, loss_history) 76 | torch.save(net, 'EEGNet_paper.pth') 77 | 78 | 79 | -------------------------------------------------------------------------------- /EEG-Inception/Readme.md: -------------------------------------------------------------------------------- 1 | This directory contain EEG-Inception implements and myself EEG_data_processing file 2 | It come from my blog,See my blog for more details https://blog.csdn.net/frankprok?type=blog 3 | -------------------------------------------------------------------------------- /EEG-Inception/cross_val_net.py: 
-------------------------------------------------------------------------------- 1 | from sklearn.model_selection import cross_val_score,StratifiedKFold 2 | from sklearn.metrics import accuracy_score 3 | import torch 4 | import torch.nn as nn 5 | from torch import optim 6 | from torch.utils.data import DataLoader, TensorDataset 7 | from EEGNet_pytorch_version import * 8 | from EEG_Inception import * 9 | import numpy as np 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | data = torch.load('C:\\Users\\24242\\PycharmProjects\\paper_rebuild\\EEG-dataprocessing\\EEG-Conformerprocessing\\combine_data_and_label\\A03_combine\\A03_combine_data.pt') 14 | label = torch.load('C:\\Users\\24242\\PycharmProjects\\paper_rebuild\\EEG-dataprocessing\\EEG-Conformerprocessing\\combine_data_and_label\\A03_combine\\A03_combine_label.pt') 15 | data = data.detach().to('cpu') 16 | label = label.detach().to('cpu') 17 | cv = StratifiedKFold(n_splits=8) 18 | 19 | acc = [] 20 | count = 0 21 | for train_index, test_index in cv.split(data, label): 22 | count += 1 23 | X_train, X_test = data[train_index], data[test_index] 24 | y_train, y_test = label[train_index], label[test_index] 25 | train_dataset = TensorDataset(X_train, y_train) 26 | batch_size = 48 27 | test_dataset = TensorDataset(X_test, y_test) 28 | train_loader = DataLoader(train_dataset, batch_size=batch_size) 29 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 30 | 31 | model = EEGInception(input_time=1000, fs=250, ncha=22, filters_per_branch=22, n_classes=4).to(device) 32 | criterion = nn.CrossEntropyLoss() 33 | optimizer = optim.Adam(model.parameters(), lr=0.001) 34 | model.train() 35 | 36 | 37 | print(f'-------------The {count} number run the model------------------') 38 | 39 | for epoch in range(500): 40 | 41 | for batch_data, batch_labels in train_loader: 42 | batch_data = batch_data.to(device).type(torch.cuda.FloatTensor) 43 | batch_labels = 
batch_labels.to(device).type(torch.cuda.LongTensor) 44 | optimizer.zero_grad() 45 | output = model(batch_data) 46 | loss = criterion(output, batch_labels) 47 | loss.backward() 48 | optimizer.step() 49 | if epoch % 10 == 0: 50 | print("Epoch: {}/{}.. ".format(epoch + 1, 500)) 51 | print("Loss: {:.12f}".format(loss.item())) 52 | 53 | 54 | model.eval() 55 | all_pre = [] 56 | all_label = [] 57 | with torch.no_grad(): 58 | for batch_data, batch_labels in test_loader: 59 | batch_data = batch_data.to(device).type(torch.cuda.FloatTensor) 60 | batch_labels = batch_labels.to(device).type(torch.cuda.LongTensor) 61 | output = model(batch_data) 62 | pre = torch.max(output.data, 1)[1] 63 | pre = pre.cpu().numpy() 64 | all_pre.extend(pre) 65 | all_label.extend(batch_labels.cpu().numpy()) 66 | 67 | acc_num = accuracy_score(y_true=np.array(all_label), y_pred=np.array(all_pre)) 68 | acc.append(acc_num) 69 | print(f'-------------Finish {count} acc recognize-------------------') 70 | print(f'-------------Test Accuracy: {acc_num}-----------------------') 71 | 72 | 73 | print('-------------Finish process-------------') 74 | print(f'cross_val_scores is{np.mean(acc)}') 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /EEGNet/EEGNet_paper.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XCZchaos/python-implementation-of-motion-imagination-classification/238c22858402277b51a80c7d9f6d03c17a9f5aa0/EEGNet/EEGNet_paper.pth -------------------------------------------------------------------------------- /EEGNet/EEGNet_version_01.py: -------------------------------------------------------------------------------- 1 | # 导入工具包 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class EEGNet(nn.Module): 9 | def __init__(self, classes_num): 10 | super(EEGNet, self).__init__() 11 | self.drop_out = 0.25 12 | 13 | 
self.block_1 = nn.Sequential( 14 | # Pads the input tensor boundaries with zero 15 | # left, right, up, bottom 16 | nn.ZeroPad2d((31, 32, 0, 0)), 17 | nn.Conv2d( 18 | in_channels=1, # input shape (1, C, T) 19 | out_channels=8, # num_filters 20 | kernel_size=(1, 64), # filter size 21 | bias=False 22 | ), # output shape (8, C, T) 23 | nn.BatchNorm2d(8) # output shape (8, C, T) 24 | ) 25 | 26 | # block 2 and 3 are implementations of Depthwise Convolution and Separable Convolution 27 | self.block_2 = nn.Sequential( 28 | nn.Conv2d( 29 | in_channels=8, # input shape (8, C, T) 30 | out_channels=16, # num_filters 31 | kernel_size=(22, 1), # filter size 32 | # group8意味着八组滤波器 33 | groups=8, 34 | bias=False 35 | ), # output shape (16, 1, T) 36 | nn.BatchNorm2d(16), # output shape (16, 1, T) 37 | nn.ELU(), 38 | nn.AvgPool2d((1, 4)), # output shape (16, 1, T//4) 39 | nn.Dropout(self.drop_out) # output shape (16, 1, T//4) 40 | ) 41 | 42 | self.block_3 = nn.Sequential( 43 | nn.ZeroPad2d((8, 8, 0, 0)), 44 | nn.Conv2d( 45 | in_channels=16, # input shape (16, 1, T//4) 46 | out_channels=16, # num_filters 47 | kernel_size=(1, 16), # filter size 48 | # 十六组滤波器 49 | groups=16, 50 | bias=False 51 | ), # output shape (16, 1, T//4) 52 | nn.Conv2d( 53 | in_channels=16, # input shape (16, 1, T//4) 54 | out_channels=16, # num_filters 55 | kernel_size=(1, 1), # filter size 56 | bias=False 57 | ), # output shape (16, 1, T//4) 58 | nn.BatchNorm2d(16), # output shape (16, 1, T//4) 59 | nn.ELU(), 60 | nn.AvgPool2d((1, 8)), # output shape (16, 1, T//32) 61 | nn.Dropout(self.drop_out) 62 | ) 63 | 64 | self.out = nn.Linear((16 * 31), classes_num) 65 | 66 | def forward(self, x): 67 | x = self.block_1(x) 68 | # print("block1", x.shape) 69 | x = self.block_2(x) 70 | # print("block2", x.shape) 71 | x = self.block_3(x) 72 | # print("block3", x.shape) 73 | 74 | x = x.view(x.size(0), -1) 75 | x = self.out(x) 76 | # return F.softmax(x, dim=1), x # return x for visualization 77 | return x 78 | # if __name__ == 
'__main__': 79 | # input = torch.randn(32,1,22,1125) 80 | # 81 | # model = EEGNet(4) 82 | # 83 | # out = model(input) 84 | # 85 | # print(model) 86 | -------------------------------------------------------------------------------- /EEGNet/Readme.md: -------------------------------------------------------------------------------- 1 | This directory contain EEGNet-pytorch version implements,and EEGNet paper code 2 | -------------------------------------------------------------------------------- /EEGNet/Test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn.metrics import accuracy_score 5 | from torch.utils.data import Dataset 6 | from torch.utils.data import DataLoader 7 | from torchsummary import summary 8 | 9 | 10 | def accuracy(output, target): 11 | pred = torch.argmax(output, dim=1) 12 | pred = pred.float() 13 | correct = torch.sum(pred == target) 14 | return 100 * correct / len(target) 15 | 16 | 17 | 18 | def plot_loss(epoch_number, loss): 19 | plt.plot(epoch_number, loss, color='red') 20 | plt.xlabel('Epoch') 21 | plt.ylabel('Loss') 22 | plt.title('Loss during test') 23 | plt.savefig("loss.jpg") 24 | plt.show() 25 | 26 | 27 | 28 | def plot_accuracy(epoch_number, accuracy): 29 | plt.plot(epoch_number, accuracy, color='orange') 30 | plt.xlabel('Epoch') 31 | plt.ylabel('Accuracy') 32 | plt.title('Accuracy during test') 33 | plt.savefig("accuracy.jpg") 34 | plt.show() 35 | 36 | 37 | 38 | def plot_recall(epoch_number, recall): 39 | plt.plot(epoch_number, recall, color='purple', label='Recall') 40 | plt.xlabel('Epoch') 41 | plt.ylabel('Rate') 42 | plt.title('Recall during test') 43 | plt.savefig("recall.jpg") 44 | plt.show() 45 | 46 | 47 | 48 | def plot_precision(epoch_number, precision): 49 | plt.plot(epoch_number, precision, color='black', label='Precision') 50 | plt.xlabel('Epoch') 51 | plt.ylabel('Rate') 52 | plt.title('Precision 
during test') 53 | plt.savefig("precision.jpg") 54 | plt.show() 55 | 56 | 57 | 58 | def plot_f1(epoch_number, f1): 59 | plt.plot(epoch_number, f1, color='yellow', label='f1') 60 | plt.xlabel('Epoch') 61 | plt.ylabel('Rate') 62 | plt.title('f1 during test') 63 | plt.savefig("f1.jpg") 64 | plt.show() 65 | 66 | 67 | 68 | def calc_recall_precision(output, target): 69 | pred = torch.argmax(output, dim=1) 70 | pred = pred.float() 71 | tp = ((pred == target) & (target == 1)).sum().item() # 正确预测为“相同”的样本数 72 | tn = ((pred == target) & (target == 0)).sum().item() # 正确预测为“不相同”的样本数 73 | fp = ((pred != target) & (target == 0)).sum().item() # 错误预测为“相同”的样本数 74 | fn = ((pred != target) & (target == 1)).sum().item() # 错误预测为“不相同”的样本数 75 | recall = tp / (tp + fn) if (tp + fn) != 0 else 0 # 计算召回率 76 | precision = tp / (tp + fp) if (tp + fp) != 0 else 0 # 计算精确度 77 | return recall, precision 78 | 79 | 80 | 81 | # ## 用于配置的帮助类 82 | class Config(): 83 | training_dir = "./data/faces/training/" 84 | testing_dir = "./data/faces/testing/" 85 | # batch_size也会影响模型的精度 86 | train_batch_size = 48 # 64 87 | test_batch_size = 48 88 | train_number_epochs = 100 # 100 89 | test_number_epochs = 20 90 | 91 | 92 | 93 | class EEGNetDataset(Dataset): 94 | # Dataset模块提供了一些接口可供实现 属于是抽象基类 95 | def __init__(self,file_path,transform=None): 96 | self.file_path = file_path 97 | 98 | # 读取文件 EEGdata与label 99 | self.data = self.parse_data_file(file_path) 100 | 101 | 102 | self.transform = transform 103 | 104 | 105 | def parse_data_file(self,file_path): 106 | 107 | data = torch.load(file_path) 108 | return np.array(data,dtype=np.float32) 109 | 110 | 111 | # dataset的抽象方法 需要自己实现,下同 112 | # 返回data的长度 size为样本量总和 22*20*50 即channels*sample -> channels * h * w 113 | def __len__(self): 114 | 115 | return len(self.data) 116 | # dataset的抽象方法 117 | # 加载数据特征的index进行截取 index参数是由getitem自动生成的 118 | def __getitem__(self,index): 119 | # 只要创建了对象就会迭代 迭代48次也就是一个batch_size 120 | # data 已变成 287*22*1000的数据 121 | # 选择第一个维度index个样本 
每个样本的shape为(22,20,50) 每一个index即为一个trail 122 | item = self.data[index,:] 123 | 124 | 125 | # 目前不会执行 126 | if self.transform: 127 | item = self.transform(item) 128 | 129 | 130 | return item 131 | 132 | if __name__ == '__main__': 133 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 134 | print(device) 135 | model_EEG_net = torch.load('EEGNet_paper.pth', map_location=torch.device(device)) 136 | data_test = torch.load('A01T_new.pt') 137 | print(data_test.shape) 138 | labels = torch.load('A01T_new_label.pt').to(device) 139 | 140 | 141 | 142 | model_EEG_net.eval() 143 | test_label = [] 144 | 145 | with torch.no_grad(): 146 | # EEGnetdata = EEGNetDataset(file_path ='C:\\Users\\24242\\DataspellProjects\\EEG_Project\\EEGNet\\A01T.pt',transform=False) 147 | # test_dataloader = DataLoader(EEGnetdata,shuffle=False,num_workers=0,batch_size=Config.train_batch_size,drop_last=True) 148 | 149 | 150 | 151 | 152 | output = model_EEG_net(data_test.to(device)) #输出 153 | # output = torch.max(output, 1)[1] 154 | 155 | # labels = labels[:-1] 156 | # print(label_test) 157 | # print(labels) 158 | 159 | # result = accuracy(output,labels) 160 | # result1 = accuracy(output=output, target=labels.to(device)) 161 | # 162 | # print(result1) 163 | # summary(model_EEG_net,input_size=(1, 22, 1000), batch_size=20) 164 | # print(result1) 165 | # 模型打印 166 | summary(model_EEG_net, batch_size=48, input_size=(1, 22, 1000)) 167 | 168 | # 0.89 0.85 0.71 0.86 0.87 169 | print(accuracy(output, labels)) 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /EEGNet/Test_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from EEGNet_version_01 import * 4 | 5 | 6 | num = torch.rand((32, 1, 22, 1000)) 7 | 8 | net = EEGNet(4) 9 | 10 | out = net(num) 11 | 12 | print(net) 13 | 14 | -------------------------------------------------------------------------------- /ML-MI/BCICIV_calib_ds1b.mat: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ML-MI/Readme.md: -------------------------------------------------------------------------------- 1 | This directory contains the code accompanying my blog. 2 | 3 | 4 | In this experiment, CSP and LDA were used for a simple classification of EEG data. Note that the EEG data were neither preprocessed nor denoised, and the models were not optimized or tuned; it is only a simple experiment. 5 | 6 | 7 | This experiment uses the MNE and scikit-learn (sklearn) Python packages. 8 | 9 | 10 | 11 | 12 | The notebook EEG_MI_ML is a simple binary classification experiment on left-hand versus right-hand motor imagery. It involves filtering, CSP and LDA; the .mat files used in the experiment are also uploaded. 13 | 14 | -------------------------------------------------------------------------------- /ML-MI/refresh_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 225, 6 | "id": "initial_id", 7 | "metadata": { 8 | "collapsed": true, 9 | "ExecuteTime": { 10 | "end_time": "2024-01-26T08:57:44.162933600Z", 11 | "start_time": "2024-01-26T08:57:44.115319400Z" 12 | } 13 | }, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import mne\n", 19 | "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", 20 | "from sklearn.metrics import accuracy_score,roc_curve,auc \n", 21 | "from make_mymodel.csp import CSP\n", 22 | "import os\n", 23 | "from sklearn.svm import SVC" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 226, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "['S100R01.edf',
'S100R02.edf', 'S100R03.edf', 'S100R04.edf', 'S100R05.edf', 'S100R06.edf', 'S100R07.edf', 'S100R08.edf', 'S100R09.edf', 'S100R10.edf', 'S100R11.edf', 'S100R12.edf', 'S100R13.edf', 'S100R14.edf']\n", 35 | "['S101R01.edf', 'S101R02.edf', 'S101R03.edf', 'S101R04.edf', 'S101R05.edf', 'S101R06.edf', 'S101R07.edf', 'S101R08.edf', 'S101R09.edf', 'S101R10.edf', 'S101R11.edf', 'S101R12.edf', 'S101R13.edf', 'S101R14.edf']\n", 36 | "['S102R01.edf', 'S102R02.edf', 'S102R03.edf', 'S102R04.edf', 'S102R05.edf', 'S102R06.edf', 'S102R07.edf', 'S102R08.edf', 'S102R09.edf', 'S102R10.edf', 'S102R11.edf', 'S102R12.edf', 'S102R13.edf', 'S102R14.edf']\n", 37 | "['S103R01.edf', 'S103R02.edf', 'S103R03.edf', 'S103R04.edf', 'S103R05.edf', 'S103R06.edf', 'S103R07.edf', 'S103R08.edf', 'S103R09.edf', 'S103R10.edf', 'S103R11.edf', 'S103R12.edf', 'S103R13.edf', 'S103R14.edf']\n", 38 | "['S104R01.edf', 'S104R02.edf', 'S104R03.edf', 'S104R04.edf', 'S104R05.edf', 'S104R06.edf', 'S104R07.edf', 'S104R08.edf', 'S104R09.edf', 'S104R10.edf', 'S104R11.edf', 'S104R12.edf', 'S104R13.edf', 'S104R14.edf']\n", 39 | "['S105R01.edf', 'S105R02.edf', 'S105R03.edf', 'S105R04.edf', 'S105R05.edf', 'S105R06.edf', 'S105R07.edf', 'S105R08.edf', 'S105R09.edf', 'S105R10.edf', 'S105R11.edf', 'S105R12.edf', 'S105R13.edf', 'S105R14.edf']\n", 40 | "['S106R01.edf', 'S106R02.edf', 'S106R03.edf', 'S106R04.edf', 'S106R05.edf', 'S106R06.edf', 'S106R07.edf', 'S106R08.edf', 'S106R09.edf', 'S106R10.edf', 'S106R11.edf', 'S106R12.edf', 'S106R13.edf', 'S106R14.edf']\n", 41 | "['S107R01.edf', 'S107R02.edf', 'S107R03.edf', 'S107R04.edf', 'S107R05.edf', 'S107R06.edf', 'S107R07.edf', 'S107R08.edf', 'S107R09.edf', 'S107R10.edf', 'S107R11.edf', 'S107R12.edf', 'S107R13.edf', 'S107R14.edf']\n", 42 | "['S108R01.edf', 'S108R02.edf', 'S108R03.edf', 'S108R04.edf', 'S108R05.edf', 'S108R06.edf', 'S108R07.edf', 'S108R08.edf', 'S108R09.edf', 'S108R10.edf', 'S108R11.edf', 'S108R12.edf', 'S108R13.edf', 'S108R14.edf']\n" 43 | ] 44 | } 45 | ], 46 | 
"source": [ 47 | "for i in range(100,109):\n", 48 | " dir_str = r'C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S'+ str(i)\n", 49 | " file_name = os.listdir(dir_str)\n", 50 | " filter_files = [file for file in file_name if file.endswith('.edf')]\n", 51 | " print(filter_files)" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "ExecuteTime": { 56 | "end_time": "2024-01-26T08:57:44.225901900Z", 57 | "start_time": "2024-01-26T08:57:44.134556700Z" 58 | } 59 | }, 60 | "id": "bbafb7fa04094e48" 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 227, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R03.edf...\n", 71 | "EDF file detected\n", 72 | "Setting channel info structure...\n", 73 | "Creating raw.info structure...\n", 74 | "Reading 0 ... 19999 = 0.000 ... 124.994 secs...\n", 75 | "Filtering raw data in 1 contiguous segment\n", 76 | "Setting up band-pass filter from 8 - 15 Hz\n", 77 | "\n", 78 | "FIR filter parameters\n", 79 | "---------------------\n", 80 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 81 | "- Windowed time-domain design (firwin) method\n", 82 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 83 | "- Lower passband edge: 8.00\n", 84 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 85 | "- Upper passband edge: 15.00 Hz\n", 86 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 87 | "- Filter length: 265 samples (1.656 s)\n", 88 | "\n", 89 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R04.edf...\n", 90 | "EDF file detected\n", 91 | "Setting channel info structure...\n", 92 | "Creating raw.info structure...\n", 93 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 94 | "Filtering raw data in 1 contiguous segment\n", 95 | "Setting up band-pass filter from 8 - 15 Hz\n", 96 | "\n", 97 | "FIR filter parameters\n", 98 | "---------------------\n", 99 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 100 | "- Windowed time-domain design (firwin) method\n", 101 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 102 | "- Lower passband edge: 8.00\n", 103 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 104 | "- Upper passband edge: 15.00 Hz\n", 105 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 106 | "- Filter length: 265 samples (1.656 s)\n", 107 | "\n", 108 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R05.edf...\n", 109 | "EDF file detected\n", 110 | "Setting channel info structure...\n", 111 | "Creating raw.info structure...\n", 112 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 113 | "Filtering raw data in 1 contiguous segment\n", 114 | "Setting up band-pass filter from 8 - 15 Hz\n", 115 | "\n", 116 | "FIR filter parameters\n", 117 | "---------------------\n", 118 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 119 | "- Windowed time-domain design (firwin) method\n", 120 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 121 | "- Lower passband edge: 8.00\n", 122 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 123 | "- Upper passband edge: 15.00 Hz\n", 124 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 125 | "- Filter length: 265 samples (1.656 s)\n", 126 | "\n", 127 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R06.edf...\n", 128 | "EDF file detected\n", 129 | "Setting channel info structure...\n", 130 | "Creating raw.info structure...\n", 131 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 132 | "Filtering raw data in 1 contiguous segment\n", 133 | "Setting up band-pass filter from 8 - 15 Hz\n", 134 | "\n", 135 | "FIR filter parameters\n", 136 | "---------------------\n", 137 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 138 | "- Windowed time-domain design (firwin) method\n", 139 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 140 | "- Lower passband edge: 8.00\n", 141 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 142 | "- Upper passband edge: 15.00 Hz\n", 143 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 144 | "- Filter length: 265 samples (1.656 s)\n" 145 | ] 146 | }, 147 | { 148 | "name": "stderr", 149 | "output_type": "stream", 150 | "text": [ 151 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 152 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 153 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 154 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n" 155 | ] 156 | }, 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R07.edf...\n", 162 | "EDF file detected\n", 163 | "Setting channel info structure...\n", 164 | "Creating raw.info structure...\n", 165 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 166 | "Filtering raw data in 1 contiguous segment\n", 167 | "Setting up band-pass filter from 8 - 15 Hz\n", 168 | "\n", 169 | "FIR filter parameters\n", 170 | "---------------------\n", 171 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 172 | "- Windowed time-domain design (firwin) method\n", 173 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 174 | "- Lower passband edge: 8.00\n", 175 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 176 | "- Upper passband edge: 15.00 Hz\n", 177 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 178 | "- Filter length: 265 samples (1.656 s)\n", 179 | "\n", 180 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R08.edf...\n", 181 | "EDF file detected\n", 182 | "Setting channel info structure...\n", 183 | "Creating raw.info structure...\n", 184 | "Reading 0 ... 19999 = 0.000 ... 124.994 secs...\n", 185 | "Filtering raw data in 1 contiguous segment\n", 186 | "Setting up band-pass filter from 8 - 15 Hz\n", 187 | "\n", 188 | "FIR filter parameters\n", 189 | "---------------------\n", 190 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 191 | "- Windowed time-domain design (firwin) method\n", 192 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 193 | "- Lower passband edge: 8.00\n", 194 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 195 | "- Upper passband edge: 15.00 Hz\n", 196 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 197 | "- Filter length: 265 samples (1.656 s)\n", 198 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R09.edf...\n", 199 | "EDF file detected\n", 200 | "Setting channel info structure...\n", 201 | "Creating raw.info structure...\n", 202 | "Reading 0 ... 
19999 = 0.000 ... 124.994 secs...\n", 203 | "Filtering raw data in 1 contiguous segment\n", 204 | "Setting up band-pass filter from 8 - 15 Hz\n", 205 | "\n", 206 | "FIR filter parameters\n", 207 | "---------------------\n", 208 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 209 | "- Windowed time-domain design (firwin) method\n", 210 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 211 | "- Lower passband edge: 8.00\n", 212 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 213 | "- Upper passband edge: 15.00 Hz\n", 214 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 215 | "- Filter length: 265 samples (1.656 s)\n", 216 | "\n", 217 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R10.edf...\n", 218 | "EDF file detected\n", 219 | "Setting channel info structure...\n", 220 | "Creating raw.info structure...\n", 221 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 222 | "Filtering raw data in 1 contiguous segment\n", 223 | "Setting up band-pass filter from 8 - 15 Hz\n", 224 | "\n", 225 | "FIR filter parameters\n", 226 | "---------------------\n", 227 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 228 | "- Windowed time-domain design (firwin) method\n", 229 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 230 | "- Lower passband edge: 8.00\n", 231 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 232 | "- Upper passband edge: 15.00 Hz\n", 233 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 234 | "- Filter length: 265 samples (1.656 s)\n" 235 | ] 236 | }, 237 | { 238 | "name": "stderr", 239 | "output_type": "stream", 240 | "text": [ 241 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 242 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 243 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 244 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n" 245 | ] 246 | }, 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R11.edf...\n", 252 | "EDF file detected\n", 253 | "Setting channel info structure...\n", 254 | "Creating raw.info structure...\n", 255 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 256 | "Filtering raw data in 1 contiguous segment\n", 257 | "Setting up band-pass filter from 8 - 15 Hz\n", 258 | "\n", 259 | "FIR filter parameters\n", 260 | "---------------------\n", 261 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 262 | "- Windowed time-domain design (firwin) method\n", 263 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 264 | "- Lower passband edge: 8.00\n", 265 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 266 | "- Upper passband edge: 15.00 Hz\n", 267 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 268 | "- Filter length: 265 samples (1.656 s)\n", 269 | "\n", 270 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R12.edf...\n", 271 | "EDF file detected\n", 272 | "Setting channel info structure...\n", 273 | "Creating raw.info structure...\n", 274 | "Reading 0 ... 19999 = 0.000 ... 124.994 secs...\n", 275 | "Filtering raw data in 1 contiguous segment\n", 276 | "Setting up band-pass filter from 8 - 15 Hz\n", 277 | "\n", 278 | "FIR filter parameters\n", 279 | "---------------------\n", 280 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 281 | "- Windowed time-domain design (firwin) method\n", 282 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 283 | "- Lower passband edge: 8.00\n", 284 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 285 | "- Upper passband edge: 15.00 Hz\n", 286 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 287 | "- Filter length: 265 samples (1.656 s)\n", 288 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R13.edf...\n", 289 | "EDF file detected\n", 290 | "Setting channel info structure...\n", 291 | "Creating raw.info structure...\n", 292 | "Reading 0 ... 
19999 = 0.000 ... 124.994 secs...\n", 293 | "Filtering raw data in 1 contiguous segment\n", 294 | "Setting up band-pass filter from 8 - 15 Hz\n", 295 | "\n", 296 | "FIR filter parameters\n", 297 | "---------------------\n", 298 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 299 | "- Windowed time-domain design (firwin) method\n", 300 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 301 | "- Lower passband edge: 8.00\n", 302 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 303 | "- Upper passband edge: 15.00 Hz\n", 304 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n", 305 | "- Filter length: 265 samples (1.656 s)\n", 306 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\EEG-files\\S001\\S001R11.edf...\n", 307 | "EDF file detected\n", 308 | "Setting channel info structure...\n", 309 | "Creating raw.info structure...\n", 310 | "Reading 0 ... 19999 = 0.000 ... 
124.994 secs...\n", 311 | "Filtering raw data in 1 contiguous segment\n", 312 | "Setting up band-pass filter from 8 - 15 Hz\n", 313 | "\n", 314 | "FIR filter parameters\n", 315 | "---------------------\n", 316 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 317 | "- Windowed time-domain design (firwin) method\n", 318 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 319 | "- Lower passband edge: 8.00\n", 320 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 321 | "- Upper passband edge: 15.00 Hz\n", 322 | "- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)\n" 323 | ] 324 | }, 325 | { 326 | "name": "stderr", 327 | "output_type": "stream", 328 | "text": [ 329 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 330 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 331 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n" 332 | ] 333 | }, 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "- Filter length: 265 samples (1.656 s)\n" 339 | ] 340 | }, 341 | { 342 | "name": "stderr", 343 | "output_type": "stream", 344 | "text": [ 345 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n" 346 | ] 347 | }, 348 | { 349 | "data": { 350 | "text/plain": "", 351 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Measurement dateAugust 12, 2009 16:15:00 GMT
ExperimenterUnknown
ParticipantX
Digitized pointsNot available
Good channels64 EEG
Bad channelsNone
EOG channelsNot available
ECG channelsNot available
Sampling frequency160.00 Hz
Highpass8.00 Hz
Lowpass15.00 Hz
FilenamesS001R11.edf
Duration00:02:05 (HH:MM:SS)
" 352 | }, 353 | "execution_count": 227, 354 | "metadata": {}, 355 | "output_type": "execute_result" 356 | } 357 | ], 358 | "source": [ 359 | "raw = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R03.edf',preload=True)\n", 360 | "raw.filter(8,15)\n", 361 | "raw1 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R04.edf',preload=True)\n", 362 | "raw1.filter(8,15)\n", 363 | "raw2 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R05.edf',preload=True)\n", 364 | "raw2.filter(8,15)\n", 365 | "raw3 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R06.edf',preload=True)\n", 366 | "raw3.filter(8,15)\n", 367 | "raw4 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R07.edf',preload=True)\n", 368 | "raw4.filter(8,15)\n", 369 | "raw5 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R08.edf',preload=True)\n", 370 | "raw5.filter(8,15)\n", 371 | "raw6 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R09.edf',preload=True)\n", 372 | "raw6.filter(8,15)\n", 373 | "raw7 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R10.edf',preload=True)\n", 374 | "raw7.filter(8,15)\n", 375 | "raw8 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R11.edf',preload=True)\n", 376 | "raw8.filter(8,15)\n", 377 | "raw9 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R12.edf',preload=True)\n", 378 | "raw9.filter(8,15)\n", 379 | "raw10 = 
mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R13.edf',preload=True)\n", 380 | "raw10.filter(8,15)\n", 381 | "raw11 = mne.io.read_raw_edf('C:\\\\Users\\\\24242\\\\Desktop\\\\AI_Reference\\\\data_bag\\\\EEG-files\\\\S001\\\\S001R11.edf',preload=True)\n", 382 | "raw11.filter(8,15)" 383 | ], 384 | "metadata": { 385 | "collapsed": false, 386 | "ExecuteTime": { 387 | "end_time": "2024-01-26T08:57:45.001638300Z", 388 | "start_time": "2024-01-26T08:57:44.178689800Z" 389 | } 390 | }, 391 | "id": "1aa20ec1436414a4" 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 228, 396 | "outputs": [], 397 | "source": [ 398 | "def choose_event_Epochs_data(raw):\n", 399 | " event,event_id = mne.events_from_annotations(raw)\n", 400 | " event_pick_id = [2,3]\n", 401 | " pick_event = mne.pick_events(event,include=event_pick_id)\n", 402 | " event_new_id = {'T1':2,'T2':3}\n", 403 | " epochs = mne.Epochs(raw,pick_event,event_new_id,tmin=-1,tmax=5,preload=True)\n", 404 | " label = epochs.events[:,-1]\n", 405 | " epochs.load_data().filter(l_freq=8,h_freq=12)\n", 406 | " data = epochs.get_data()\n", 407 | " print(data.shape,label.shape)\n", 408 | " return data,label" 409 | ], 410 | "metadata": { 411 | "collapsed": false, 412 | "ExecuteTime": { 413 | "end_time": "2024-01-26T08:57:45.017373900Z", 414 | "start_time": "2024-01-26T08:57:44.937289900Z" 415 | } 416 | }, 417 | "id": "4f352135ec155406" 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 229, 422 | "outputs": [ 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": [ 427 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 428 | "Not setting metadata\n", 429 | "15 matching events found\n", 430 | "Setting baseline interval to [-1.0, 0.0] s\n", 431 | "Applying baseline correction (mode: mean)\n", 432 | "0 projection items activated\n", 433 | "Using data from preloaded Raw for 15 events and 961 original time points 
...\n", 434 | "1 bad epochs dropped\n", 435 | "Setting up band-pass filter from 8 - 12 Hz\n", 436 | "\n", 437 | "FIR filter parameters\n", 438 | "---------------------\n", 439 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 440 | "- Windowed time-domain design (firwin) method\n", 441 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 442 | "- Lower passband edge: 8.00\n", 443 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 444 | "- Upper passband edge: 12.00 Hz\n", 445 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 446 | "- Filter length: 265 samples (1.656 s)\n", 447 | "(14, 64, 961) (14,)\n" 448 | ] 449 | }, 450 | { 451 | "name": "stderr", 452 | "output_type": "stream", 453 | "text": [ 454 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 455 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 456 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 457 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 458 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 459 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 460 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "data,label = choose_event_Epochs_data(raw)" 466 | ], 467 | "metadata": { 468 | "collapsed": false, 469 | "ExecuteTime": { 470 | "end_time": "2024-01-26T08:57:45.142484700Z", 471 | "start_time": "2024-01-26T08:57:44.952979700Z" 472 | } 473 | }, 474 | "id": "81c3965c45e75e0a" 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 230, 479 | "outputs": [ 480 | { 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 485 | "Not setting metadata\n", 486 | "15 matching events found\n", 487 | "Setting baseline interval to [-1.0, 0.0] s\n", 488 | "Applying baseline correction (mode: mean)\n", 489 | 
"0 projection items activated\n", 490 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 491 | "1 bad epochs dropped\n", 492 | "Setting up band-pass filter from 8 - 12 Hz\n", 493 | "\n", 494 | "FIR filter parameters\n", 495 | "---------------------\n", 496 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 497 | "- Windowed time-domain design (firwin) method\n", 498 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 499 | "- Lower passband edge: 8.00\n", 500 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 501 | "- Upper passband edge: 12.00 Hz\n", 502 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 503 | "- Filter length: 265 samples (1.656 s)\n" 504 | ] 505 | }, 506 | { 507 | "name": "stderr", 508 | "output_type": "stream", 509 | "text": [ 510 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 511 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 512 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 513 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n" 514 | ] 515 | }, 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "(14, 64, 961) (14,)\n" 521 | ] 522 | }, 523 | { 524 | "name": "stderr", 525 | "output_type": "stream", 526 | "text": [ 527 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 528 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 529 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 530 | ] 531 | } 532 | ], 533 | "source": [ 534 | "data1,label1 = choose_event_Epochs_data(raw1)" 535 | ], 536 | "metadata": { 537 | "collapsed": false, 538 | "ExecuteTime": { 539 | "end_time": "2024-01-26T08:57:45.335822300Z", 540 | "start_time": "2024-01-26T08:57:45.093949700Z" 541 | } 542 | }, 543 | "id": "e551d1f286d8aa54" 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 231, 548 | "outputs": [ 549 | { 550 | 
"name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 554 | "Not setting metadata\n", 555 | "15 matching events found\n", 556 | "Setting baseline interval to [-1.0, 0.0] s\n", 557 | "Applying baseline correction (mode: mean)\n", 558 | "0 projection items activated\n", 559 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 560 | "1 bad epochs dropped\n", 561 | "Setting up band-pass filter from 8 - 12 Hz\n", 562 | "\n", 563 | "FIR filter parameters\n", 564 | "---------------------\n", 565 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 566 | "- Windowed time-domain design (firwin) method\n", 567 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 568 | "- Lower passband edge: 8.00\n", 569 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 570 | "- Upper passband edge: 12.00 Hz\n", 571 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 572 | "- Filter length: 265 samples (1.656 s)\n" 573 | ] 574 | }, 575 | { 576 | "name": "stderr", 577 | "output_type": "stream", 578 | "text": [ 579 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 580 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 581 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 582 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 583 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n" 584 | ] 585 | }, 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "(14, 64, 961) (14,)\n" 591 | ] 592 | }, 593 | { 594 | "name": "stderr", 595 | "output_type": "stream", 596 | "text": [ 597 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 598 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 599 | ] 600 | } 601 | ], 602 | "source": [ 603 | "data2,label2 = choose_event_Epochs_data(raw2)" 604 | ], 605 | "metadata": { 606 
| "collapsed": false, 607 | "ExecuteTime": { 608 | "end_time": "2024-01-26T08:57:45.351464700Z", 609 | "start_time": "2024-01-26T08:57:45.221222100Z" 610 | } 611 | }, 612 | "id": "4f32bbf225ef9e12" 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 232, 617 | "outputs": [ 618 | { 619 | "name": "stdout", 620 | "output_type": "stream", 621 | "text": [ 622 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 623 | "Not setting metadata\n", 624 | "15 matching events found\n", 625 | "Setting baseline interval to [-1.0, 0.0] s\n", 626 | "Applying baseline correction (mode: mean)\n", 627 | "0 projection items activated\n", 628 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 629 | "1 bad epochs dropped\n", 630 | "Setting up band-pass filter from 8 - 12 Hz\n", 631 | "\n", 632 | "FIR filter parameters\n", 633 | "---------------------\n", 634 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 635 | "- Windowed time-domain design (firwin) method\n", 636 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 637 | "- Lower passband edge: 8.00\n", 638 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 639 | "- Upper passband edge: 12.00 Hz\n", 640 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 641 | "- Filter length: 265 samples (1.656 s)\n" 642 | ] 643 | }, 644 | { 645 | "name": "stderr", 646 | "output_type": "stream", 647 | "text": [ 648 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 649 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 650 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 651 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 652 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n" 653 | ] 654 | }, 655 | { 656 | "name": "stdout", 657 | "output_type": "stream", 658 | "text": [ 659 | "(14, 64, 961) (14,)\n" 660 | ] 661 | }, 662 | { 663 | "name": 
"stderr", 664 | "output_type": "stream", 665 | "text": [ 666 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 667 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 668 | ] 669 | } 670 | ], 671 | "source": [ 672 | "data3,label3 = choose_event_Epochs_data(raw3)" 673 | ], 674 | "metadata": { 675 | "collapsed": false, 676 | "ExecuteTime": { 677 | "end_time": "2024-01-26T08:57:45.493627200Z", 678 | "start_time": "2024-01-26T08:57:45.351464700Z" 679 | } 680 | }, 681 | "id": "e309ac74c728cbc4" 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 233, 686 | "outputs": [ 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 692 | "Not setting metadata\n", 693 | "15 matching events found\n", 694 | "Setting baseline interval to [-1.0, 0.0] s\n", 695 | "Applying baseline correction (mode: mean)\n", 696 | "0 projection items activated\n", 697 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 698 | "1 bad epochs dropped\n", 699 | "Setting up band-pass filter from 8 - 12 Hz\n", 700 | "\n", 701 | "FIR filter parameters\n", 702 | "---------------------\n", 703 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 704 | "- Windowed time-domain design (firwin) method\n", 705 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 706 | "- Lower passband edge: 8.00\n", 707 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 708 | "- Upper passband edge: 12.00 Hz\n", 709 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 710 | "- Filter length: 265 samples (1.656 s)\n" 711 | ] 712 | }, 713 | { 714 | "name": "stderr", 715 | "output_type": "stream", 716 | "text": [ 717 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 718 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 719 | "[Parallel(n_jobs=1)]: Done 161 tasks | 
elapsed: 0.0s\n", 720 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 721 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 722 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n" 723 | ] 724 | }, 725 | { 726 | "name": "stdout", 727 | "output_type": "stream", 728 | "text": [ 729 | "(14, 64, 961) (14,)\n" 730 | ] 731 | }, 732 | { 733 | "name": "stderr", 734 | "output_type": "stream", 735 | "text": [ 736 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 737 | ] 738 | } 739 | ], 740 | "source": [ 741 | "data4,label4 = choose_event_Epochs_data(raw4)" 742 | ], 743 | "metadata": { 744 | "collapsed": false, 745 | "ExecuteTime": { 746 | "end_time": "2024-01-26T08:57:45.603932Z", 747 | "start_time": "2024-01-26T08:57:45.477907Z" 748 | } 749 | }, 750 | "id": "e0b4bbd9dd75139e" 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 234, 755 | "outputs": [ 756 | { 757 | "name": "stdout", 758 | "output_type": "stream", 759 | "text": [ 760 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 761 | "Not setting metadata\n", 762 | "15 matching events found\n", 763 | "Setting baseline interval to [-1.0, 0.0] s\n", 764 | "Applying baseline correction (mode: mean)\n", 765 | "0 projection items activated\n", 766 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 767 | "1 bad epochs dropped\n", 768 | "Setting up band-pass filter from 8 - 12 Hz\n", 769 | "\n", 770 | "FIR filter parameters\n", 771 | "---------------------\n", 772 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 773 | "- Windowed time-domain design (firwin) method\n", 774 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 775 | "- Lower passband edge: 8.00\n", 776 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 777 | "- Upper passband edge: 12.00 Hz\n", 778 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 779 | "- Filter 
length: 265 samples (1.656 s)\n" 780 | ] 781 | }, 782 | { 783 | "name": "stderr", 784 | "output_type": "stream", 785 | "text": [ 786 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 787 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 788 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 789 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 790 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n" 791 | ] 792 | }, 793 | { 794 | "name": "stdout", 795 | "output_type": "stream", 796 | "text": [ 797 | "(14, 64, 961) (14,)\n" 798 | ] 799 | }, 800 | { 801 | "name": "stderr", 802 | "output_type": "stream", 803 | "text": [ 804 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 805 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 806 | ] 807 | } 808 | ], 809 | "source": [ 810 | "data5,label5 = choose_event_Epochs_data(raw5)" 811 | ], 812 | "metadata": { 813 | "collapsed": false, 814 | "ExecuteTime": { 815 | "end_time": "2024-01-26T08:57:45.735943600Z", 816 | "start_time": "2024-01-26T08:57:45.603932Z" 817 | } 818 | }, 819 | "id": "67d0b341927b8e08" 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 235, 824 | "outputs": [ 825 | { 826 | "name": "stdout", 827 | "output_type": "stream", 828 | "text": [ 829 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 830 | "Not setting metadata\n", 831 | "15 matching events found\n", 832 | "Setting baseline interval to [-1.0, 0.0] s\n", 833 | "Applying baseline correction (mode: mean)\n", 834 | "0 projection items activated\n", 835 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 836 | "1 bad epochs dropped\n", 837 | "Setting up band-pass filter from 8 - 12 Hz\n", 838 | "\n", 839 | "FIR filter parameters\n", 840 | "---------------------\n", 841 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 842 | "- Windowed time-domain design (firwin) method\n", 843 | "- Hamming window with 0.0194 passband 
ripple and 53 dB stopband attenuation\n", 844 | "- Lower passband edge: 8.00\n", 845 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 846 | "- Upper passband edge: 12.00 Hz\n", 847 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 848 | "- Filter length: 265 samples (1.656 s)\n" 849 | ] 850 | }, 851 | { 852 | "name": "stderr", 853 | "output_type": "stream", 854 | "text": [ 855 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 856 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 857 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 858 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 859 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 860 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n" 861 | ] 862 | }, 863 | { 864 | "name": "stdout", 865 | "output_type": "stream", 866 | "text": [ 867 | "(14, 64, 961) (14,)\n" 868 | ] 869 | }, 870 | { 871 | "name": "stderr", 872 | "output_type": "stream", 873 | "text": [ 874 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 875 | ] 876 | } 877 | ], 878 | "source": [ 879 | "data6,label6 = choose_event_Epochs_data(raw6)" 880 | ], 881 | "metadata": { 882 | "collapsed": false, 883 | "ExecuteTime": { 884 | "end_time": "2024-01-26T08:57:45.857502Z", 885 | "start_time": "2024-01-26T08:57:45.730934200Z" 886 | } 887 | }, 888 | "id": "7a670ba615b4140c" 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": 236, 893 | "outputs": [ 894 | { 895 | "name": "stdout", 896 | "output_type": "stream", 897 | "text": [ 898 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 899 | "Not setting metadata\n", 900 | "15 matching events found\n", 901 | "Setting baseline interval to [-1.0, 0.0] s\n", 902 | "Applying baseline correction (mode: mean)\n", 903 | "0 projection items activated\n", 904 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 905 | "1 bad epochs dropped\n", 906 
| "Setting up band-pass filter from 8 - 12 Hz\n", 907 | "\n", 908 | "FIR filter parameters\n", 909 | "---------------------\n", 910 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 911 | "- Windowed time-domain design (firwin) method\n", 912 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 913 | "- Lower passband edge: 8.00\n", 914 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 915 | "- Upper passband edge: 12.00 Hz\n", 916 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 917 | "- Filter length: 265 samples (1.656 s)\n" 918 | ] 919 | }, 920 | { 921 | "name": "stderr", 922 | "output_type": "stream", 923 | "text": [ 924 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 925 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 926 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 927 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 928 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n" 929 | ] 930 | }, 931 | { 932 | "name": "stdout", 933 | "output_type": "stream", 934 | "text": [ 935 | "(14, 64, 961) (14,)\n" 936 | ] 937 | }, 938 | { 939 | "name": "stderr", 940 | "output_type": "stream", 941 | "text": [ 942 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 943 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 944 | ] 945 | } 946 | ], 947 | "source": [ 948 | "data7,label7 = choose_event_Epochs_data(raw7)" 949 | ], 950 | "metadata": { 951 | "collapsed": false, 952 | "ExecuteTime": { 953 | "end_time": "2024-01-26T08:57:46.031457300Z", 954 | "start_time": "2024-01-26T08:57:45.857502Z" 955 | } 956 | }, 957 | "id": "bd7730d2c855cc29" 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": 237, 962 | "outputs": [ 963 | { 964 | "name": "stdout", 965 | "output_type": "stream", 966 | "text": [ 967 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 968 | "Not setting metadata\n", 969 
| "15 matching events found\n", 970 | "Setting baseline interval to [-1.0, 0.0] s\n", 971 | "Applying baseline correction (mode: mean)\n", 972 | "0 projection items activated\n", 973 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 974 | "1 bad epochs dropped\n", 975 | "Setting up band-pass filter from 8 - 12 Hz\n", 976 | "\n", 977 | "FIR filter parameters\n", 978 | "---------------------\n", 979 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 980 | "- Windowed time-domain design (firwin) method\n", 981 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 982 | "- Lower passband edge: 8.00\n", 983 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 984 | "- Upper passband edge: 12.00 Hz\n", 985 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 986 | "- Filter length: 265 samples (1.656 s)\n" 987 | ] 988 | }, 989 | { 990 | "name": "stderr", 991 | "output_type": "stream", 992 | "text": [ 993 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 994 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 995 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 996 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 997 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 998 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n", 999 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 1000 | ] 1001 | }, 1002 | { 1003 | "name": "stdout", 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "(14, 64, 961) (14,)\n" 1007 | ] 1008 | } 1009 | ], 1010 | "source": [ 1011 | "data8,label8 = choose_event_Epochs_data(raw8)" 1012 | ], 1013 | "metadata": { 1014 | "collapsed": false, 1015 | "ExecuteTime": { 1016 | "end_time": "2024-01-26T08:57:46.110297200Z", 1017 | "start_time": "2024-01-26T08:57:45.983428200Z" 1018 | } 1019 | }, 1020 | "id": "d96506677fae1ea2" 1021 | }, 1022 | { 1023 | 
"cell_type": "code", 1024 | "execution_count": 238, 1025 | "outputs": [ 1026 | { 1027 | "name": "stdout", 1028 | "output_type": "stream", 1029 | "text": [ 1030 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 1031 | "Not setting metadata\n", 1032 | "15 matching events found\n", 1033 | "Setting baseline interval to [-1.0, 0.0] s\n", 1034 | "Applying baseline correction (mode: mean)\n", 1035 | "0 projection items activated\n", 1036 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 1037 | "1 bad epochs dropped\n", 1038 | "Setting up band-pass filter from 8 - 12 Hz\n", 1039 | "\n", 1040 | "FIR filter parameters\n", 1041 | "---------------------\n", 1042 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 1043 | "- Windowed time-domain design (firwin) method\n", 1044 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 1045 | "- Lower passband edge: 8.00\n", 1046 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 1047 | "- Upper passband edge: 12.00 Hz\n", 1048 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 1049 | "- Filter length: 265 samples (1.656 s)\n" 1050 | ] 1051 | }, 1052 | { 1053 | "name": "stderr", 1054 | "output_type": "stream", 1055 | "text": [ 1056 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 1057 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 1058 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 1059 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 1060 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 1061 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n" 1062 | ] 1063 | }, 1064 | { 1065 | "name": "stdout", 1066 | "output_type": "stream", 1067 | "text": [ 1068 | "(14, 64, 961) (14,)\n" 1069 | ] 1070 | }, 1071 | { 1072 | "name": "stderr", 1073 | "output_type": "stream", 1074 | "text": [ 1075 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 
0.0s\n" 1076 | ] 1077 | } 1078 | ], 1079 | "source": [ 1080 | "data9,label9 = choose_event_Epochs_data(raw9)" 1081 | ], 1082 | "metadata": { 1083 | "collapsed": false, 1084 | "ExecuteTime": { 1085 | "end_time": "2024-01-26T08:57:46.268207500Z", 1086 | "start_time": "2024-01-26T08:57:46.110297200Z" 1087 | } 1088 | }, 1089 | "id": "a5c338594250d135" 1090 | }, 1091 | { 1092 | "cell_type": "code", 1093 | "execution_count": 239, 1094 | "outputs": [ 1095 | { 1096 | "name": "stdout", 1097 | "output_type": "stream", 1098 | "text": [ 1099 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 1100 | "Not setting metadata\n", 1101 | "15 matching events found\n", 1102 | "Setting baseline interval to [-1.0, 0.0] s\n", 1103 | "Applying baseline correction (mode: mean)\n", 1104 | "0 projection items activated\n", 1105 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 1106 | "1 bad epochs dropped\n", 1107 | "Setting up band-pass filter from 8 - 12 Hz\n", 1108 | "\n", 1109 | "FIR filter parameters\n", 1110 | "---------------------\n", 1111 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 1112 | "- Windowed time-domain design (firwin) method\n", 1113 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 1114 | "- Lower passband edge: 8.00\n", 1115 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 1116 | "- Upper passband edge: 12.00 Hz\n", 1117 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 1118 | "- Filter length: 265 samples (1.656 s)\n" 1119 | ] 1120 | }, 1121 | { 1122 | "name": "stderr", 1123 | "output_type": "stream", 1124 | "text": [ 1125 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 1126 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 1127 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 1128 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 1129 | "[Parallel(n_jobs=1)]: Done 449 tasks 
| elapsed: 0.0s\n", 1130 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n" 1131 | ] 1132 | }, 1133 | { 1134 | "name": "stdout", 1135 | "output_type": "stream", 1136 | "text": [ 1137 | "(14, 64, 961) (14,)\n", 1138 | "Used Annotations descriptions: ['T0', 'T1', 'T2']\n", 1139 | "Not setting metadata\n", 1140 | "15 matching events found\n", 1141 | "Setting baseline interval to [-1.0, 0.0] s\n", 1142 | "Applying baseline correction (mode: mean)\n", 1143 | "0 projection items activated\n", 1144 | "Using data from preloaded Raw for 15 events and 961 original time points ...\n", 1145 | "1 bad epochs dropped\n", 1146 | "Setting up band-pass filter from 8 - 12 Hz\n", 1147 | "\n", 1148 | "FIR filter parameters\n", 1149 | "---------------------\n", 1150 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 1151 | "- Windowed time-domain design (firwin) method\n", 1152 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 1153 | "- Lower passband edge: 8.00\n", 1154 | "- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)\n", 1155 | "- Upper passband edge: 12.00 Hz\n", 1156 | "- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)\n", 1157 | "- Filter length: 265 samples (1.656 s)\n" 1158 | ] 1159 | }, 1160 | { 1161 | "name": "stderr", 1162 | "output_type": "stream", 1163 | "text": [ 1164 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n", 1165 | "[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s\n", 1166 | "[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.0s\n", 1167 | "[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 0.0s\n", 1168 | "[Parallel(n_jobs=1)]: Done 287 tasks | elapsed: 0.0s\n", 1169 | "[Parallel(n_jobs=1)]: Done 449 tasks | elapsed: 0.0s\n", 1170 | "[Parallel(n_jobs=1)]: Done 647 tasks | elapsed: 0.0s\n" 1171 | ] 1172 | }, 1173 | { 1174 | "name": "stdout", 1175 | "output_type": "stream", 1176 | "text": [ 1177 | "(14, 64, 961) (14,)\n" 1178 | ] 1179 | }, 1180 | { 
1181 | "name": "stderr", 1182 | "output_type": "stream", 1183 | "text": [ 1184 | "[Parallel(n_jobs=1)]: Done 881 tasks | elapsed: 0.0s\n" 1185 | ] 1186 | } 1187 | ], 1188 | "source": [ 1189 | "data10,label10 = choose_event_Epochs_data(raw10)\n", 1190 | "data11,label11 = choose_event_Epochs_data(raw11)" 1191 | ], 1192 | "metadata": { 1193 | "collapsed": false, 1194 | "ExecuteTime": { 1195 | "end_time": "2024-01-26T08:57:46.505247700Z", 1196 | "start_time": "2024-01-26T08:57:46.236753Z" 1197 | } 1198 | }, 1199 | "id": "a89eb50d8a898de1" 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "execution_count": 240, 1204 | "outputs": [], 1205 | "source": [ 1206 | "def gather_data(data,data1,label,label1):\n", 1207 | " data2 = np.concatenate((data,data1),axis=0)\n", 1208 | " label2 = np.append(label,label1)\n", 1209 | " return data2,label2\n", 1210 | " " 1211 | ], 1212 | "metadata": { 1213 | "collapsed": false, 1214 | "ExecuteTime": { 1215 | "end_time": "2024-01-26T08:57:46.505247700Z", 1216 | "start_time": "2024-01-26T08:57:46.473775500Z" 1217 | } 1218 | }, 1219 | "id": "e61ab4a26426ef41" 1220 | }, 1221 | { 1222 | "cell_type": "code", 1223 | "execution_count": 241, 1224 | "outputs": [], 1225 | "source": [ 1226 | "data_2,label_2 = gather_data(data,data1,label,label1)\n", 1227 | "data_3,label_3 = gather_data(data_2,data2,label_2,label2)\n", 1228 | "data_4,label_4 = gather_data(data_3,data3,label_3,label3)\n", 1229 | "data_5,label_5 = gather_data(data_4,data4,label_4,label4)\n", 1230 | "data_6,label_6 = gather_data(data_5,data5,label_5,label5)\n", 1231 | "data_all,label_all = gather_data(data_6,data6,label_6,label6)\n", 1232 | "data_all,label_all = gather_data(data_all,data7,label_all,label7)\n", 1233 | "Test_data,Test_label = gather_data(data8,data9,label8,label9)\n", 1234 | "Test_data,Test_label = gather_data(Test_data,data10,Test_label,label10)\n", 1235 | "Test_data,Test_label = gather_data(Test_data,data11,Test_label,label11)" 1236 | ], 1237 | "metadata": { 1238 | 
"collapsed": false, 1239 | "ExecuteTime": { 1240 | "end_time": "2024-01-26T08:57:46.583596200Z", 1241 | "start_time": "2024-01-26T08:57:46.489554300Z" 1242 | } 1243 | }, 1244 | "id": "c82ba8e5bdd55537" 1245 | }, 1246 | { 1247 | "cell_type": "code", 1248 | "execution_count": 242, 1249 | "outputs": [ 1250 | { 1251 | "data": { 1252 | "text/plain": "(112, 64, 961)" 1253 | }, 1254 | "execution_count": 242, 1255 | "metadata": {}, 1256 | "output_type": "execute_result" 1257 | } 1258 | ], 1259 | "source": [ 1260 | "data_all.shape" 1261 | ], 1262 | "metadata": { 1263 | "collapsed": false, 1264 | "ExecuteTime": { 1265 | "end_time": "2024-01-26T08:57:46.583596200Z", 1266 | "start_time": "2024-01-26T08:57:46.552232100Z" 1267 | } 1268 | }, 1269 | "id": "5f4f13520c78ff58" 1270 | }, 1271 | { 1272 | "cell_type": "code", 1273 | "execution_count": 243, 1274 | "outputs": [ 1275 | { 1276 | "data": { 1277 | "text/plain": "(112,)" 1278 | }, 1279 | "execution_count": 243, 1280 | "metadata": {}, 1281 | "output_type": "execute_result" 1282 | } 1283 | ], 1284 | "source": [ 1285 | "label_all.shape" 1286 | ], 1287 | "metadata": { 1288 | "collapsed": false, 1289 | "ExecuteTime": { 1290 | "end_time": "2024-01-26T08:57:46.583596200Z", 1291 | "start_time": "2024-01-26T08:57:46.567858800Z" 1292 | } 1293 | }, 1294 | "id": "ef4d80c8e71501e2" 1295 | }, 1296 | { 1297 | "cell_type": "code", 1298 | "execution_count": 244, 1299 | "outputs": [ 1300 | { 1301 | "data": { 1302 | "text/plain": "((56, 64, 961), (56,))" 1303 | }, 1304 | "execution_count": 244, 1305 | "metadata": {}, 1306 | "output_type": "execute_result" 1307 | } 1308 | ], 1309 | "source": [ 1310 | "Test_data.shape,Test_label.shape" 1311 | ], 1312 | "metadata": { 1313 | "collapsed": false, 1314 | "ExecuteTime": { 1315 | "end_time": "2024-01-26T08:57:46.599326800Z", 1316 | "start_time": "2024-01-26T08:57:46.583596200Z" 1317 | } 1318 | }, 1319 | "id": "88f48d1769e1e1f3" 1320 | }, 1321 | { 1322 | "cell_type": "code", 1323 | "execution_count": 245, 
1324 | "outputs": [ 1325 | { 1326 | "data": { 1327 | "text/plain": "((112, 10), (56, 10))" 1328 | }, 1329 | "execution_count": 245, 1330 | "metadata": {}, 1331 | "output_type": "execute_result" 1332 | } 1333 | ], 1334 | "source": [ 1335 | "csp = CSP(n_components=10)\n", 1336 | "X_csp_all = csp.fit_transform(data_all,label_all)\n", 1337 | "Test_data_csp = csp.transform(Test_data)\n", 1338 | "X_csp_all.shape,Test_data_csp.shape" 1339 | ], 1340 | "metadata": { 1341 | "collapsed": false, 1342 | "ExecuteTime": { 1343 | "end_time": "2024-01-26T08:57:46.757773200Z", 1344 | "start_time": "2024-01-26T08:57:46.599326800Z" 1345 | } 1346 | }, 1347 | "id": "f6c3dc755cc4471e" 1348 | }, 1349 | { 1350 | "cell_type": "code", 1351 | "execution_count": 246, 1352 | "outputs": [], 1353 | "source": [ 1354 | "svc = SVC(kernel='linear')\n", 1355 | "svc.fit(X_csp_all,label_all)\n", 1356 | "y_pred = svc.predict(Test_data_csp)\n", 1357 | "acc = accuracy_score(Test_label,y_pred)" 1358 | ], 1359 | "metadata": { 1360 | "collapsed": false, 1361 | "ExecuteTime": { 1362 | "end_time": "2024-01-26T08:57:46.773527900Z", 1363 | "start_time": "2024-01-26T08:57:46.757773200Z" 1364 | } 1365 | }, 1366 | "id": "b9a2b439a7c72a2a" 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "execution_count": 247, 1371 | "outputs": [ 1372 | { 1373 | "data": { 1374 | "text/plain": "0.625" 1375 | }, 1376 | "execution_count": 247, 1377 | "metadata": {}, 1378 | "output_type": "execute_result" 1379 | } 1380 | ], 1381 | "source": [ 1382 | "acc" 1383 | ], 1384 | "metadata": { 1385 | "collapsed": false, 1386 | "ExecuteTime": { 1387 | "end_time": "2024-01-26T08:57:46.836602Z", 1388 | "start_time": "2024-01-26T08:57:46.773527900Z" 1389 | } 1390 | }, 1391 | "id": "24914a975aa48686" 1392 | }, 1393 | { 1394 | "cell_type": "code", 1395 | "execution_count": 248, 1396 | "outputs": [ 1397 | { 1398 | "name": "stdout", 1399 | "output_type": "stream", 1400 | "text": [ 1401 | "0.6785714285714286\n", 1402 | "[0.15153537 0.29535886 
0.33919194 0.03493327 0.79198978 0.06977115\n", 1403 | " 0.00823536 0.32510551 0.71568288 0.05275308 0.14087436 0.88222796\n", 1404 | " 0.4045542 0.51859964 0.1742485 0.26558711 0.04140591 0.4023113\n", 1405 | " 0.04647999 0.48569648 0.69364207 0.02200009 0.16639753 0.59960677\n", 1406 | " 0.1656594 0.47406661 0.36024742 0.3653722 0.11608622 0.90631418\n", 1407 | " 0.08251557 0.97969391 0.8347348 0.57485275 0.83467668 0.03731981\n", 1408 | " 0.66060446 0.17135124 0.0134101 0.79913199 0.87322459 0.39165128\n", 1409 | " 0.15153537 0.29535886 0.33919194 0.03493327 0.79198978 0.06977115\n", 1410 | " 0.00823536 0.32510551 0.71568288 0.05275308 0.14087436 0.88222796\n", 1411 | " 0.4045542 0.51859964]\n", 1412 | "[2 2 2 2 3 2 2 2 3 2 2 3 2 3 2 2 2 2 2 2 3 2 2 3 2 2 2 2 2 3 2 3 3 3 3 2 3\n", 1413 | " 2 2 3 3 2 2 2 2 2 3 2 2 2 3 2 2 3 2 3]\n", 1414 | "[-1, 1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1]\n" 1415 | ] 1416 | } 1417 | ], 1418 | "source": [ 1419 | "lda = LinearDiscriminantAnalysis()\n", 1420 | "lda.fit(X_csp_all,label_all)\n", 1421 | "pred_lda = lda.predict(Test_data_csp)\n", 1422 | "print(accuracy_score(y_pred=pred_lda,y_true=Test_label))\n", 1423 | "y_score = lda.predict_proba(Test_data_csp)[:,1]\n", 1424 | "print(y_score)\n", 1425 | "print(lda.predict(Test_data_csp))\n", 1426 | "new_Test_label = []\n", 1427 | "for i in Test_label:\n", 1428 | " if i == 2:\n", 1429 | " new_Test_label.append(-1)\n", 1430 | " else:\n", 1431 | " new_Test_label.append(1)\n", 1432 | "print(new_Test_label)" 1433 | ], 1434 | "metadata": { 1435 | "collapsed": false, 1436 | "ExecuteTime": { 1437 | "end_time": "2024-01-26T08:57:46.837107500Z", 1438 | "start_time": "2024-01-26T08:57:46.789239200Z" 1439 | } 1440 | }, 1441 | "id": "9955628275c33567" 1442 | }, 1443 | { 1444 | "cell_type": "code", 1445 | "execution_count": 249, 1446 | "outputs": [], 
1447 | "source": [ 1448 | "fpr,tpr,thresholds = roc_curve(y_true=new_Test_label,y_score=y_score)\n", 1449 | "roc_acu = auc(fpr,tpr)" 1450 | ], 1451 | "metadata": { 1452 | "collapsed": false, 1453 | "ExecuteTime": { 1454 | "end_time": "2024-01-26T08:57:46.837107500Z", 1455 | "start_time": "2024-01-26T08:57:46.805125900Z" 1456 | } 1457 | }, 1458 | "id": "9bb1451afef6273c" 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "execution_count": 250, 1463 | "outputs": [ 1464 | { 1465 | "data": { 1466 | "text/plain": "
", 1467 | "image/png": "" 1468 | }, 1469 | "metadata": {}, 1470 | "output_type": "display_data" 1471 | } 1472 | ], 1473 | "source": [ 1474 | "plt.figure(figsize=(8,8))\n", 1475 | "plt.plot(fpr,tpr,color='darkorange',lw=2,label='ROC curve(ACU={:.2f})'.format(roc_acu))\n", 1476 | "plt.plot([0, 1], [0, 1],color='navy',lw=2,linestyle='--',label='Random Guess')\n", 1477 | "\n", 1478 | "plt.xlabel('Flase Positive Rate')\n", 1479 | "plt.ylabel('True Positive Rate')\n", 1480 | "plt.title('Receiver Operating Characteristic (ROC) Curve')\n", 1481 | "plt.legend(loc='lower right')\n", 1482 | "plt.show()" 1483 | ], 1484 | "metadata": { 1485 | "collapsed": false, 1486 | "ExecuteTime": { 1487 | "end_time": "2024-01-26T08:57:46.962842200Z", 1488 | "start_time": "2024-01-26T08:57:46.820816600Z" 1489 | } 1490 | }, 1491 | "id": "955e75e3389d1fb" 1492 | }, 1493 | { 1494 | "cell_type": "code", 1495 | "execution_count": 251, 1496 | "outputs": [ 1497 | { 1498 | "name": "stdout", 1499 | "output_type": "stream", 1500 | "text": [ 1501 | "[0.8035714285714286, 0.7678571428571429, 0.7678571428571429, 0.6964285714285714, 0.6785714285714286, 0.6607142857142857, 0.6428571428571429, 0.7142857142857143, 0.7321428571428571, 0.6964285714285714]\n" 1502 | ] 1503 | } 1504 | ], 1505 | "source": [ 1506 | "# 计算精度与CSP共空间模式中n_component的数量对LDA分类器的影响\n", 1507 | "component_num = [2,4,6,8,10,12,14,16,18,20]\n", 1508 | "acc_list = []\n", 1509 | "for i in component_num:\n", 1510 | " csp = CSP(n_components=i)\n", 1511 | " X_csp_all = csp.fit_transform(data_all,label_all)\n", 1512 | " Test_data_csp = csp.transform(Test_data)\n", 1513 | " lda = LinearDiscriminantAnalysis()\n", 1514 | " lda.fit(X_csp_all,label_all)\n", 1515 | " pred_lda = lda.predict(Test_data_csp)\n", 1516 | " acc_lda = accuracy_score(y_true=Test_label,y_pred=pred_lda)\n", 1517 | " acc_list.append(acc_lda)\n", 1518 | "print(acc_list)" 1519 | ], 1520 | "metadata": { 1521 | "collapsed": false, 1522 | "ExecuteTime": { 1523 | "end_time": 
"2024-01-26T08:57:48.560734100Z", 1524 | "start_time": "2024-01-26T08:57:46.962842200Z" 1525 | } 1526 | }, 1527 | "id": "2f3911827c688610" 1528 | }, 1529 | { 1530 | "cell_type": "code", 1531 | "execution_count": 252, 1532 | "outputs": [], 1533 | "source": [ 1534 | "def plot_acc(X,Y):\n", 1535 | " plt.figure(figsize=(8,8))\n", 1536 | " plt.plot(X,Y)\n", 1537 | " plt.xlabel('different model')\n", 1538 | " plt.ylabel('Acc')\n", 1539 | " " 1540 | ], 1541 | "metadata": { 1542 | "collapsed": false, 1543 | "ExecuteTime": { 1544 | "end_time": "2024-01-26T08:57:48.576443200Z", 1545 | "start_time": "2024-01-26T08:57:48.560734100Z" 1546 | } 1547 | }, 1548 | "id": "81f535613cc08110" 1549 | }, 1550 | { 1551 | "cell_type": "code", 1552 | "execution_count": 253, 1553 | "outputs": [ 1554 | { 1555 | "data": { 1556 | "text/plain": "
", 1557 | "image/png": "" 1558 | }, 1559 | "metadata": {}, 1560 | "output_type": "display_data" 1561 | } 1562 | ], 1563 | "source": [ 1564 | "plot_acc(component_num,acc_list)" 1565 | ], 1566 | "metadata": { 1567 | "collapsed": false, 1568 | "ExecuteTime": { 1569 | "end_time": "2024-01-26T08:57:48.707894900Z", 1570 | "start_time": "2024-01-26T08:57:48.576443200Z" 1571 | } 1572 | }, 1573 | "id": "5d04d8d3f7653904" 1574 | } 1575 | ], 1576 | "metadata": { 1577 | "kernelspec": { 1578 | "display_name": "Python 3", 1579 | "language": "python", 1580 | "name": "python3" 1581 | }, 1582 | "language_info": { 1583 | "codemirror_mode": { 1584 | "name": "ipython", 1585 | "version": 2 1586 | }, 1587 | "file_extension": ".py", 1588 | "mimetype": "text/x-python", 1589 | "name": "python", 1590 | "nbconvert_exporter": "python", 1591 | "pygments_lexer": "ipython2", 1592 | "version": "2.7.6" 1593 | } 1594 | }, 1595 | "nbformat": 4, 1596 | "nbformat_minor": 5 1597 | } 1598 | -------------------------------------------------------------------------------- /New train folder/EEGNet.py: -------------------------------------------------------------------------------- 1 | # 导入工具包 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class EEGNet(nn.Module): 9 | def __init__(self, classes_num): 10 | super(EEGNet, self).__init__() 11 | self.drop_out = 0.25 12 | 13 | self.block_1 = nn.Sequential( 14 | # Pads the input tensor boundaries with zero 15 | # left, right, up, bottom 16 | nn.ZeroPad2d((31, 32, 0, 0)), 17 | nn.Conv2d( 18 | in_channels=1, # input shape (1, C, T) 19 | out_channels=8, # num_filters 20 | kernel_size=(1, 64), # filter size 21 | bias=False 22 | ), # output shape (8, C, T) 23 | nn.BatchNorm2d(8) # output shape (8, C, T) 24 | ) 25 | 26 | # block 2 and 3 are implementations of Depthwise Convolution and Separable Convolution 27 | self.block_2 = nn.Sequential( 28 | nn.Conv2d( 29 | in_channels=8, # input shape (8, C, T) 30 | 
out_channels=16, # num_filters 31 | kernel_size=(22, 1), # filter size 32 | # group8意味着八组滤波器 33 | groups=8, 34 | bias=False 35 | ), # output shape (16, 1, T) 36 | nn.BatchNorm2d(16), # output shape (16, 1, T) 37 | nn.ELU(), 38 | nn.AvgPool2d((1, 4)), # output shape (16, 1, T//4) 39 | nn.Dropout(self.drop_out) # output shape (16, 1, T//4) 40 | ) 41 | 42 | self.block_3 = nn.Sequential( 43 | nn.ZeroPad2d((8, 8, 0, 0)), 44 | nn.Conv2d( 45 | in_channels=16, # input shape (16, 1, T//4) 46 | out_channels=16, # num_filters 47 | kernel_size=(1, 16), # filter size 48 | # 十六组滤波器 49 | groups=16, 50 | bias=False 51 | ), # output shape (16, 1, T//4) 52 | nn.Conv2d( 53 | in_channels=16, # input shape (16, 1, T//4) 54 | out_channels=16, # num_filters 55 | kernel_size=(1, 1), # filter size 56 | bias=False 57 | ), # output shape (16, 1, T//4) 58 | nn.BatchNorm2d(16), # output shape (16, 1, T//4) 59 | nn.ELU(), 60 | nn.AvgPool2d((1, 8)), # output shape (16, 1, T//32) 61 | nn.Dropout(self.drop_out) 62 | ) 63 | 64 | self.out = nn.Linear((16 * 31), classes_num) 65 | 66 | def forward(self, x): 67 | x = self.block_1(x) 68 | # print("block1", x.shape) 69 | x = self.block_2(x) 70 | # print("block2", x.shape) 71 | x = self.block_3(x) 72 | # print("block3", x.shape) 73 | 74 | x = x.view(x.size(0), -1) 75 | x = self.out(x) 76 | 77 | # return F.softmax(x, dim=1), x # return x for visualization 78 | return x -------------------------------------------------------------------------------- /New train folder/Readme.md: -------------------------------------------------------------------------------- 1 | The new training method has been updated.This file folder contain some metrics and whole train and test processing. 2 | 3 | 4 | 5 | 6 | Modify the contents of the function in get_source_data to read your own data files.The preprocessing code is not publicly available, but thanks for understanding. 7 | 8 | 9 | 10 | 11 | If you have some questions, Please contact me. 
12 | 13 | 14 | 15 | 16 | E-mail:asherxiong552@gmail.com 17 | -------------------------------------------------------------------------------- /New train folder/confusion_matrix1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XCZchaos/python-implementation-of-motion-imagination-classification/238c22858402277b51a80c7d9f6d03c17a9f5aa0/New train folder/confusion_matrix1.png -------------------------------------------------------------------------------- /New train folder/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn.metrics import confusion_matrix 5 | import itertools 6 | from scipy import stats 7 | from sklearn import manifold 8 | from einops import reduce 9 | from scipy.linalg import eigh 10 | 11 | def plot_confusion_matrix(y_true, y_pred, sub, title = "Confusion matrix - 2a", 12 | cmap=plt.cm.Blues, save_flg=True): 13 | 14 | y_pred = y_pred.cpu().detach().numpy() 15 | y_true = y_true.cpu().detach().numpy() 16 | classes = [str(i) for i in range(4)] 17 | labels = range(4) 18 | 19 | cm = confusion_matrix(y_true, y_pred, labels=labels, normalize='true') 20 | plt.figure(figsize=(14, 12)) 21 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 22 | plt.title(title, fontsize=40) 23 | plt.colorbar() 24 | tick_marks = np.arange(len(classes)) 25 | plt.xticks(tick_marks, classes, fontsize=20) 26 | plt.yticks(tick_marks, classes, fontsize=20) 27 | 28 | # print('Confusion matrix, without normalization') 29 | 30 | thresh = cm.max() / 2. 
31 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 32 | plt.text(j, i, format(cm[i, j], '.2f'), 33 | horizontalalignment="center", 34 | color="white" if cm[i, j] > thresh else "black", 35 | fontsize=30) 36 | 37 | plt.ylabel('True label', fontsize=30) 38 | plt.xlabel('Predicted label', fontsize=30) 39 | # save your path,you can choose your path and change it 40 | if save_flg: 41 | plt.savefig("confusion_matrix" + str(sub) + ".png") 42 | 43 | 44 | 45 | def plt_tsne(data, label, per, nsub): 46 | 47 | data = data.cpu().detach().numpy() 48 | data = reduce(data, 'b n e -> b e', reduction='mean') 49 | label = label.cpu().detach().numpy() 50 | 51 | tsne = manifold.TSNE(n_components=2, perplexity=per, init='pca', random_state=166, learning_rate=200, n_iter=1000) 52 | X_tsne = tsne.fit_transform(data) 53 | 54 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 55 | X_norm = (X_tsne - x_min) / (x_max - x_min) 56 | plt.figure(figsize=(10, 8)) 57 | 58 | color_list = ['blue', 'red', 'green', 'orange'] 59 | 60 | unique_labels = np.unique(label) 61 | num_classes = len(unique_labels) 62 | 63 | label_to_color = {unique_labels[i]: color_list[i % len(color_list)] for i in range(num_classes)} 64 | 65 | for i in range(X_norm.shape[0]): 66 | plt.scatter(X_norm[i, 0], X_norm[i, 1], color=label_to_color[label[i]], s=50, alpha=0.8) # 增加点的大小和透明度 67 | 68 | plt.xticks([]) 69 | plt.yticks([]) 70 | plt.title('t-SNE visualization') 71 | 72 | plt.savefig('EEGNet_%d.png' % (nsub), dpi=600) 73 | 74 | 75 | 76 | def plot_metrics(train_losses, train_accuracies, nSub): 77 | epochs = range(1, len(train_losses) + 1) 78 | fig, ax1 = plt.subplots(figsize=(10, 5)) 79 | 80 | ax2 = ax1.twinx() 81 | ax1.plot(epochs, train_losses, 'g-', label='Training loss') 82 | ax2.plot(epochs, train_accuracies, 'r-', label='Training accuracy') 83 | 84 | ax1.set_xlabel('Epochs') 85 | ax1.set_ylabel('Loss', color='g') 86 | ax2.set_ylabel('Accuracy', color='r') 87 | 88 | 89 | lines1, labels1 = 
ax1.get_legend_handles_labels() 90 | lines2, labels2 = ax2.get_legend_handles_labels() 91 | 92 | 93 | combined_legend = ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right', bbox_to_anchor=(0.5, 0.5)) 94 | 95 | ax1.add_artist(combined_legend) # 添加合并的图例 96 | 97 | plt.title('Training Loss and Accuracy') 98 | plt.savefig('training_metrics_subject_%d.png' % (nSub)) 99 | plt.show() 100 | 101 | -------------------------------------------------------------------------------- /New train folder/read_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import mne\n", 10 | "import scipy\n", 11 | "import torch\n", 12 | "from torchsummary import summary\n", 13 | "from EEGNet import EEGNet" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "# 读取数据" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "(1000, 22, 288)\n", 33 | "(288, 1)\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "raw = scipy.io.loadmat(r'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A01T.mat')\n", 39 | "# 4x250, 22, 288 250为sample_rate 4为4秒的取值\n", 40 | "data = raw['data']\n", 41 | "label = raw['label']\n", 42 | "print(data.shape)\n", 43 | "print(label.shape)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# 进行维度的转换" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "(288, 22, 1000)\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "data = data.transpose(2, 1, 0)\n", 68 | "print(data.shape)" 69 | ] 70 | }, 71 | { 72 | "cell_type": 
"markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "# 标准的EEGNet及其神经网络的输入" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "(288, 1, 22, 1000)" 87 | ] 88 | }, 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "data = data.reshape(data.shape[0], 1, data.shape[1], data.shape[2])\n", 96 | "data.shape" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "# 查看EEGNet的模型架构" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "----------------------------------------------------------------\n", 116 | " Layer (type) Output Shape Param #\n", 117 | "================================================================\n", 118 | " ZeroPad2d-1 [48, 1, 22, 1063] 0\n", 119 | " Conv2d-2 [48, 8, 22, 1000] 512\n", 120 | " BatchNorm2d-3 [48, 8, 22, 1000] 16\n", 121 | " Conv2d-4 [48, 16, 1, 1000] 352\n", 122 | " BatchNorm2d-5 [48, 16, 1, 1000] 32\n", 123 | " ELU-6 [48, 16, 1, 1000] 0\n", 124 | " AvgPool2d-7 [48, 16, 1, 250] 0\n", 125 | " Dropout-8 [48, 16, 1, 250] 0\n", 126 | " ZeroPad2d-9 [48, 16, 1, 266] 0\n", 127 | " Conv2d-10 [48, 16, 1, 251] 256\n", 128 | " Conv2d-11 [48, 16, 1, 251] 256\n", 129 | " BatchNorm2d-12 [48, 16, 1, 251] 32\n", 130 | " ELU-13 [48, 16, 1, 251] 0\n", 131 | " AvgPool2d-14 [48, 16, 1, 31] 0\n", 132 | " Dropout-15 [48, 16, 1, 31] 0\n", 133 | " Linear-16 [48, 4] 1,988\n", 134 | "================================================================\n", 135 | "Total params: 3,444\n", 136 | "Trainable params: 3,444\n", 137 | "Non-trainable params: 0\n", 138 | "----------------------------------------------------------------\n", 139 | "Input size (MB): 4.03\n", 140 | "Forward/backward pass 
size (MB): 165.78\n", 141 | "Params size (MB): 0.01\n", 142 | "Estimated Total Size (MB): 169.83\n", 143 | "----------------------------------------------------------------\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "model = EEGNet(4).to('cuda')\n", 149 | "summary(input_size=(1, 22, 1000), batch_size=48, device='cuda', model=model)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "Extracting EDF parameters from C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A01T.gdf...\n", 162 | "GDF file detected\n", 163 | "Setting channel info structure...\n", 164 | "Could not determine channel type of the following channels, they will be set as EEG:\n", 165 | "EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right\n", 166 | "Creating raw.info structure...\n" 167 | ] 168 | }, 169 | { 170 | "name": "stderr", 171 | "output_type": "stream", 172 | "text": [ 173 | "E:\\Anaconda3\\lib\\contextlib.py:126: RuntimeWarning: Channel names are not unique, found duplicates for: {'EEG'}. Applying running numbers for duplicates.\n", 174 | " next(self.gen)\n" 175 | ] 176 | }, 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "Reading 0 ... 672527 = 0.000 ... 2690.108 secs...\n" 182 | ] 183 | }, 184 | { 185 | "data": { 186 | "text/html": [ 187 | "
\n", 188 | " General\n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | "
Measurement dateJanuary 17, 2005 12:00:00 GMT
ExperimenterUnknown
ParticipantA01
\n", 211 | "
\n", 212 | "
\n", 213 | " Channels\n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | "
Digitized pointsNot available
Good channels25 EEG
Bad channelsNone
EOG channelsNot available
ECG channelsNot available
\n", 238 | "
\n", 239 | "
\n", 240 | " Data\n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | "
Sampling frequency250.00 Hz
Highpass0.50 Hz
Lowpass100.00 Hz
\n", 264 | "
" 265 | ], 266 | "text/plain": [ 267 | "" 280 | ] 281 | }, 282 | "execution_count": 7, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | } 286 | ], 287 | "source": [ 288 | "raw = mne.io.read_raw_gdf(r'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A01T.gdf', preload=True)\n", 289 | "raw.info" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 10, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/html": [ 300 | "
\n", 301 | " General\n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | "
Measurement dateJanuary 17, 2005 12:00:00 GMT
ExperimenterUnknown
ParticipantA01
\n", 324 | "
\n", 325 | "
\n", 326 | " Channels\n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | "
Digitized pointsNot available
Good channels25 EEG
Bad channelsNone
EOG channelsNot available
ECG channelsNot available
\n", 351 | "
\n", 352 | "
\n", 353 | " Data\n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | "
Sampling frequency250.00 Hz
Highpass0.50 Hz
Lowpass100.00 Hz
FilenamesA01T.gdf
Duration00:44:51 (HH:MM:SS)
\n", 387 | "
" 388 | ], 389 | "text/plain": [ 390 | "" 391 | ] 392 | }, 393 | "execution_count": 10, 394 | "metadata": {}, 395 | "output_type": "execute_result" 396 | }, 397 | { 398 | "ename": "", 399 | "evalue": "", 400 | "output_type": "error", 401 | "traceback": [ 402 | "\u001b[1;31m在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。\n", 403 | "\u001b[1;31m请查看单元格中的代码,以确定故障的可能原因。\n", 404 | "\u001b[1;31m单击此处了解详细信息。\n", 405 | "\u001b[1;31m有关更多详细信息,请查看 Jupyter log。" 406 | ] 407 | } 408 | ], 409 | "source": [ 410 | "raw" 411 | ] 412 | } 413 | ], 414 | "metadata": { 415 | "kernelspec": { 416 | "display_name": ".venv", 417 | "language": "python", 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.9.7" 431 | } 432 | }, 433 | "nbformat": 4, 434 | "nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /New train folder/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from EEGNet import EEGNet 5 | import scipy 6 | from torch.autograd import Variable 7 | from sklearn.metrics import cohen_kappa_score 8 | from metrics import plot_confusion_matrix, plot_metrics 9 | import random 10 | 11 | gpus = [0] 12 | 13 | # 输入的shape 14 | # (288, 1, 22, 1000) (trial, cov_number, channel, timepiont) (batch_size, channel, width, height) 15 | # (batch_size, RGB, height, width) 16 | class Trans(): 17 | def __init__(self, nsub): 18 | super(Trans, self).__init__() 19 | self.batch_size = 50 20 | self.n_epochs = 1000 21 | self.lr = 0.0002 22 | self.b1 = 0.5 23 | self.b2 = 0.9 24 | self.nSub = nsub 25 | self.start_epoch = 0 26 | self.root = 
'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\' # the path of data 27 | self.pretrain = False 28 | self.Tensor = torch.cuda.FloatTensor 29 | self.LongTensor = torch.cuda.LongTensor 30 | self.model = EEGNet(4).cuda() 31 | self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))]) 32 | self.model = self.model.cuda() 33 | self.centers = {} 34 | self.criterion = nn.CrossEntropyLoss() 35 | 36 | def get_source_data(self): 37 | self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub) 38 | self.train_data = self.total_data['data'] 39 | self.train_label = self.total_data['label'] 40 | self.train_data = np.transpose(self.train_data, (2, 1, 0)) 41 | self.train_data = np.expand_dims(self.train_data, axis=1) # (288, 1, 22, 1000) 42 | self.train_label = np.transpose(self.train_label) 43 | self.allData = self.train_data 44 | self.allLabel = self.train_label[0] 45 | self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub) 46 | self.test_data = self.test_tmp['data'] 47 | self.test_label = self.test_tmp['label'] 48 | self.test_data = np.transpose(self.test_data, (2, 1, 0)) 49 | self.test_data = np.expand_dims(self.test_data, axis=1) 50 | self.test_label = np.transpose(self.test_label) 51 | self.testData = self.test_data 52 | self.testLabel = self.test_label[0] 53 | # 归一化 54 | target_mean = np.mean(self.allData) 55 | target_std = np.std(self.allData) 56 | self.allData = (self.allData - target_mean) / target_std 57 | self.testData = (self.testData - target_mean) / target_std 58 | return self.allData, self.allLabel, self.testData, self.testLabel 59 | 60 | def update_lr(self, optimizer, lr): 61 | for param_group in optimizer.param_groups: 62 | param_group['lr'] = lr 63 | 64 | def calculate_kappa(self, y_true, y_pred): 65 | kappa = cohen_kappa_score(y_true.cpu().numpy(), y_pred.cpu().numpy()) 66 | return kappa 67 | 68 | def interaug(self, timg, label): 69 | aug_data = [] 70 | aug_label = [] 71 | for cls4aug in 
range(4): 72 | cls_idx = np.where(label == cls4aug + 1) 73 | tmp_data = timg[cls_idx] 74 | tmp_label = label[cls_idx] 75 | tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000)) 76 | for ri in range(int(self.batch_size / 4)): 77 | for rj in range(8): 78 | rand_idx = np.random.randint(0, tmp_data.shape[0], 8) 79 | tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125] 80 | aug_data.append(tmp_aug_data) 81 | aug_label.append(tmp_label[:int(self.batch_size / 4)]) 82 | aug_data = np.concatenate(aug_data) 83 | aug_label = np.concatenate(aug_label) 84 | aug_shuffle = np.random.permutation(len(aug_data)) 85 | aug_data = aug_data[aug_shuffle, :, :] 86 | aug_label = aug_label[aug_shuffle] 87 | aug_data = torch.from_numpy(aug_data).cuda() 88 | aug_data = aug_data.float() 89 | aug_label = torch.from_numpy(aug_label - 1).cuda() 90 | aug_label = aug_label.long() 91 | return aug_data, aug_label 92 | 93 | def train(self): 94 | img, label, test_data, test_label = self.get_source_data() 95 | img = torch.from_numpy(img) 96 | label = torch.from_numpy(label - 1) 97 | dataset = torch.utils.data.TensorDataset(img, label) 98 | self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) 99 | test_data = torch.from_numpy(test_data) 100 | test_label = torch.from_numpy(test_label - 1) 101 | test_dataset = torch.utils.data.TensorDataset(test_data, test_label) 102 | self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) 103 | test_data = Variable(test_data.type(self.Tensor)) 104 | test_label = Variable(test_label.type(self.LongTensor)) 105 | bestAcc = 0 106 | averAcc = 0 107 | num = 0 108 | Y_true = 0 109 | Y_pred = 0 110 | 111 | train_losses = [] 112 | train_accuracies = [] 113 | 114 | for e in range(self.n_epochs): 115 | self.model.train() 116 | for i, (img, label) in enumerate(self.dataloader): 117 | img = 
Variable(img.cuda().type(self.Tensor)) 118 | label = Variable(label.cuda().type(self.LongTensor)) 119 | aug_data, aug_label = self.interaug(self.allData, self.allLabel) 120 | img = torch.cat((img, aug_data)) 121 | label = torch.cat((label, aug_label)) 122 | outputs = self.model(img) 123 | print(label.shape) 124 | loss = self.criterion(outputs, label) 125 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) 126 | self.optimizer.zero_grad() 127 | loss.backward() 128 | self.optimizer.step() 129 | 130 | 131 | 132 | 133 | if (e + 1) % 1 == 0: 134 | self.model.eval() 135 | outputs_test= self.model(test_data) 136 | loss_test = self.criterion(outputs_test, test_label) 137 | y_pred = torch.max(outputs_test, 1)[1] 138 | acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) 139 | train_pred = torch.max(outputs, 1)[1] 140 | train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) 141 | print('Epoch:', e, 142 | ' Train loss:', loss.detach().cpu().numpy(), 143 | ' Test loss:', loss_test.detach().cpu().numpy(), 144 | ' Train accuracy:', train_acc, 145 | ' Test accuracy is:', acc) 146 | train_losses.append(loss.detach().cpu().numpy()) 147 | train_accuracies.append(train_acc) 148 | num = num + 1 149 | averAcc = averAcc + acc 150 | if acc > bestAcc: 151 | bestAcc = acc 152 | Y_true = test_label 153 | Y_pred = y_pred 154 | 155 | 156 | 157 | # if e == self.n_epochs - 1: 158 | 159 | # plt_tsne(outputs_test, test_label, per=30, nsub=self.nSub) 160 | # plt_tsne(test_label, per=30, nsub=self.nSub) 161 | 162 | # you can save the model state_dict in your path 163 | # torch.save(self.model.module.state_dict(), '/root/autodl-tmp/model_picture/TFCformer_model_Subject_%d.pth' % (self.nSub)) 164 | averAcc = averAcc / num 165 | print('The average accuracy is:', averAcc) 166 | print('The best accuracy is:', bestAcc) 167 | kappa = self.calculate_kappa(Y_true, Y_pred) 168 | 
print('The kappa score is:', kappa) 169 | plot_metrics(train_losses, train_accuracies, self.nSub) 170 | return bestAcc, averAcc, Y_true, Y_pred, kappa 171 | 172 | 173 | 174 | def main(): 175 | best = 0 176 | aver = 0 177 | for i in range(9): 178 | seed_n = np.random.randint(500) 179 | print('seed is ' + str(seed_n)) 180 | random.seed(seed_n) 181 | np.random.seed(seed_n) 182 | torch.manual_seed(seed_n) 183 | torch.cuda.manual_seed(seed_n) 184 | torch.cuda.manual_seed_all(seed_n) 185 | print('Subject %d' % (i+1)) 186 | trans = Trans(i + 1) 187 | bestAcc, averAcc, Y_true, Y_pred, kappa = trans.train() 188 | print('THE BEST ACCURACY IS ' + str(bestAcc)) 189 | plot_confusion_matrix(Y_true, Y_pred, i+1) 190 | best = best + bestAcc 191 | aver = aver + averAcc 192 | if i == 0: 193 | yt = Y_true 194 | yp = Y_pred 195 | else: 196 | yt = torch.cat((yt, Y_true)) 197 | yp = torch.cat((yp, Y_pred)) 198 | plot_confusion_matrix(yt, yp, 666) 199 | 200 | 201 | if __name__ == '__main__': 202 | 203 | main() -------------------------------------------------------------------------------- /New train folder/training_metrics_subject_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XCZchaos/python-implementation-of-motion-imagination-classification/238c22858402277b51a80c7d9f6d03c17a9f5aa0/New train folder/training_metrics_subject_1.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # python-implementation-of-motion-imagination-classification 2 | This repository corresponds to the code for my blog 3 | 4 | 5 | 6 | 7 | And some of my notes from BCI learning are uploaded here, but the notes are in Chinese 8 | 9 | Thank you for your interest in this repository 10 | 11 | 12 | 13 | If you have some questions, Please contact me. 
14 | 15 | 16 | E-mail:asherxiong552@gmail.com 17 | 18 | 19 | 20 | 21 | update log: 22 | Added the new folder 'New train folder'. If you want to learn more, please read the file 'Readme.md' in 'New train folder'. 23 | --------Edited on 2024/8/14 24 | -------------------------------------------------------------------------------- /mne包教程.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XCZchaos/python-implementation-of-motion-imagination-classification/238c22858402277b51a80c7d9f6d03c17a9f5aa0/mne包教程.pdf -------------------------------------------------------------------------------- /脑机接口导论笔记.md: -------------------------------------------------------------------------------- 1 | ### 论文知识 2 | 3 | 在进行运动想象实验时,我们一般使用C3、C4两个通道,Cz通道对运动想象结果的影响其实并没有那么大。C3、C4和Cz三个电极位于大脑感觉运动皮层的上方,而β节律是运动想象所需的主要节律。 4 | 5 | ![img](https://pic2.zhimg.com/80/v2-9956a285c0e96767ba12912178af99d1_720w.webp) 6 | 7 | 8 | 9 | 10 | 11 | 想象肢体运动时,对应皮层投射区域出现脑电节律调制现象。当想象左手运动时,大脑右半球C4区域神经电活动增强,该区域的信息加工导致EEG功率谱出现减弱现象;当想象右手运动时,大脑左半球C3区域神经电活动增强,其EEG功率谱出现减弱现象;而对应脚和舌头的ERD/ERS现象分别出现在大脑顶叶和颞叶,大致为Cz和CP6区域。不同运动想象任务所激活的脑区如下图1所示,其空间分布符合周围神经纤维与大脑皮层的投射关系,与脑功能分区图相一致,因此,运动想象脑电信号具有空间特性。 12 | 13 | ERD/ERS现象主要出现在频率范围为8-12Hz的mu节律和频率范围为18-26Hz的beta节律中,其中mu节律的变化最为显著。但是不同任务的运动想象脑电在特征频段上也有差别:对应手的ERD经常出现在10-12Hz和20-24Hz,对应脚的ERD经常出现在7-8Hz和20-24Hz,对应舌头的ERS经常出现在10-11Hz;并且运动想象脑电频段因人而异。因此,运动想象脑电信号还具有频段特性,合理选择最佳的滤波器及滤波频段是后续处理中提高分类效果的关键。 14 | 15 | ### 第三章、记录大脑信号和刺激大脑 16 | 17 | EEG信号反映的是上千个神经元产生的突触后电位的总和。EEG信号通常会受到其他因素的干扰,比如眼动等与实验无关的行为产生的电位,以及一些设备因素,比如工频干扰(国内为50Hz),我们一般使用陷波滤波去除工频干扰。 18 | 19 | 记住几个比较特殊的电极位置,比如位于耳垂(或乳突)的参考电极A1、A2,以及中间的电极Cz和Cz两侧的电极C3、C4,这对运动想象实验来说比较重要。 20 | 21 | 在测量EEG信号时,可以测量一对电极之间的电位差;也可以测量各个电极相对于一个中性电极或所有电极平均值的电位(后者称为共同平均参考,CAR)。 22 | 23 | 当我们进行运动想象时,alpha节律(mu节律,范围为8~13Hz)会减少,而beta节律(范围为13~30Hz)会增强。 24 | 25 | ### 第四章、信号处理 26 | 27 | #### 1.锋电位分类 28 | 29 | 我们看脑电图像时,会看到很多条周期函数混合在一起,因为电极记录的是局部神经元的放电,并把该局部所有神经元的锋电位混合在了一起。 30 | 31 |
锋电位分类的目的是在信号处理时,可靠地分离和提取每个记录电极采集到的、由单个神经元发出的锋电位。 32 | 33 | #### 2.频域分析 34 | 35 | 傅里叶变换将区间[-T/2, T/2]中的时间信号s(t)分解为无限多个正弦和余弦函数的加权和。 36 | 37 | 正弦和余弦波可以视为基函数,将这些函数以不同的权值加权求和,就可以得到不同的信号,这一过程即为信号合成,其权值可以由输入信号计算出来。 38 | 39 | 傅里叶变换可运用于滤波,如利用陷波滤波去除工频干扰:将输入信号进行重构,使重构后的信号不包含该频率成分。权值的求解方法是将时域信号与对应频率的正弦或余弦函数相乘,再在[-T/2, T/2]上积分,即 $a_n = \frac{2}{T}\int_{-T/2}^{T/2} s(t)\cos\left(\frac{2\pi n t}{T}\right)dt$,$b_n = \frac{2}{T}\int_{-T/2}^{T/2} s(t)\sin\left(\frac{2\pi n t}{T}\right)dt$。 40 | 41 | #### 3.频谱特征 42 | 43 | 从一段时间间隔内大脑信号的功率谱中提取特征,在进行运动想象时,beta节律会增强,alpha节律会减少。具体的频谱怎么用,参照mne包中的笔记。 44 | 45 | #### 4.小波分析 46 | 47 | 由于EEG等大脑信号是非平稳信号,所以常在短时窗内进行傅里叶分析,这种方法被称为短时傅里叶变换。小波变换的基函数不再是正弦和余弦函数,而是由一个有限长的母小波通过伸缩和平移得到;小波函数的长度有限,使其可以用来表示非周期信号或含有陡变、不连续的信号。 48 | 49 | 小波变换也是用基函数的线性组合来表示原信号。 50 | 51 | #### 5.时域分析 52 | 53 | ##### 分形维数 54 | 55 | 如果一个信号表现出自相似性,即信号的一部分与整个信号具有相似性,则认为它是分形的。分形维数是对这种相似性的定量测量。EEG信号的分形维数一般在1.4到1.7之间,值越大表明发生了高锋电位的活动,例如癫痫。 56 | 57 | ##### 自回归模型 58 | 59 | 自然信号在时间或者空间等维度上存在相关的趋势,因此常常能用之前的一些测量值来预测下一个测量值。贝叶斯滤波和卡尔曼滤波也基于这种递归原理,利用上一项计算下一项。 60 | 61 | #### 6.空间滤波 62 | 63 | 空间滤波通过几种方式对不同位置记录到的大脑信号进行变换,目的是增强局部活动、减弱各通道中的共有噪声、降低数据维数、识别隐含的源,以及找到能最大程度区分不同类别的投影。 64 | 65 | 双极信号通过计算两个电极之间的电位差来进行处理。 66 | 67 | 拉普拉斯滤波是将某电极的信号减去其四个正交最邻近电极信号的平均值,共同平均参考则是减去所有电极的平均值。 68 | 69 | ##### 主成分分析 70 | 71 | 如果有64个通道,一次实验中可获取由N个64维向量组成的数据集,即一次实验中会获得N×64个数据。 72 | 73 | PCA会寻找L维数据中方差最大的方向,将原坐标旋转变换到最大方差的方向上。如果原数据是冗余的,那么只保留方差大的几个方向、丢弃方差小的方向上的坐标,就可以降低数据维度。 74 | 75 | 对于大多数自然信号,从多个位置记录到的大脑信号都可能是冗余的,因此可以进行降维操作。PCA可以利用这些冗余,寻找数据差异性的主方向。如果找到了这些低维子空间的主方向,就可以将原数据投影到主方向上,从而达到降维效果,得到的M维向量就可以用作分类任务的特征向量。 76 | 77 | 复习一下协方差矩阵的概念,以及用拉格朗日乘子法计算特征值的方法。 78 | 79 | ##### 独立分量分析 80 | 81 | PCA能确保各分量之间的协方差为零(即不相关),但无法确保两个随机变量在高阶统计意义上相互独立:在高阶情形下,两个变量的相关性不一定为零,比如它们的平方之间可能相关,从而导致变量不独立。 82 | 83 | 用地形图判断眼电伪迹的原理是,对前额和眼睛附近的主成分进行正值和负值的加权,从而得到区分度。 84 | 85 | ICA与PCA的不同在于:W矩阵中的行向量不再需要满足正交条件;而且PCA中向量a的维度小于或等于输入x的维数,而ICA特征向量的维数可以小于、等于或者大于输入的维数。 86 | 87 | ##### 共空间模式 88 | 89 | PCA和ICA都是无监督学习,而CSP是一种监督方法:每个训练数据都是被标记过的,每个数据向量的类别都是已知的。 90 | 91 | CSP寻找空间滤波器,使滤波后的数据在其中一类上的方差达到最大,而在另一类上的方差达到最小,因而得到的特征向量增强了两类之间的差别。 92 | 93 | CSP实质上是使BCI所用特征的可区分程度最大化。 94 | 95 | --------------------------------------------------------------------------------