├── requirements.txt ├── utils ├── __init__.py ├── LogRecord.py ├── CsvRecord.py ├── data_augment.py ├── utils.py ├── func_utils.py ├── network.py ├── loss.py └── dataloader.py ├── .gitignore ├── README.md ├── EEG_cross_subject_loader.py ├── Baseline.py ├── models ├── ShallowConvNet.py ├── EEGNet.py └── EEGNetv4.py ├── MAML.py ├── ANIL.py ├── MDMAML.py ├── test_model.py ├── ProtoNets.py └── SHOT.py /requirements.txt: -------------------------------------------------------------------------------- 1 | learn2learn==0.1.6 2 | numpy==1.21.3 3 | pandas==1.3.4 4 | scikit_learn==1.1.3 5 | scipy==1.7.1 6 | torch==1.10.0 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/11 11:43 下午 3 | # @Author : wenzhang 4 | # @File : __init__.py.py 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .vscode 4 | .mypy_cache 5 | *.py[cod] 6 | *.sw[px] 7 | *$py.class 8 | *.csd 9 | .idea 10 | 11 | data 12 | runs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-Learning for Fast and Privacy-Preserving Calibration of EEG-based BCIs 2 | The official implementation of our paper [`Meta-Learning for Fast and Privacy-Preserving Source Knowledge Transfer of EEG-based BCIs`](https://ieeexplore.ieee.org/document/9942685) (**IEEE CIM, 2022**) 3 | 4 | Please contact me at syoungli@hust.edu.cn or lsyyoungll@gmail.com for any questions regarding the paper, and use Issues for any questions regarding the code. 
5 | 6 | If you have questions with downloading EEG data, do check out [DeepTransferEEG repo](https://github.com/sylyoung/DeepTransferEEG) for easier usage of public datasets. -------------------------------------------------------------------------------- /utils/LogRecord.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/12/8 19:08 3 | # @Author : wenzhang 4 | # @File : LogRecord.py 5 | 6 | import torch as tr 7 | import os.path as osp 8 | from datetime import datetime 9 | from datetime import timedelta, timezone 10 | 11 | from utils.utils import create_folder 12 | 13 | 14 | class LogRecord: 15 | def __init__(self, args): 16 | self.args = args 17 | self.result_dir = args.result_dir 18 | self.data_env = 'gpu' if tr.cuda.get_device_name(0) != 'GeForce GTX 1660 Ti' else 'local' 19 | self.data_name = args.data 20 | self.method = args.method 21 | 22 | def log_init(self): 23 | create_folder(self.result_dir, self.args.data_env, self.args.local_dir) 24 | 25 | if self.data_env in ['local', 'mac']: 26 | time_str = datetime.utcnow().replace(tzinfo=timezone.utc).astimezone( 27 | timezone(timedelta(hours=8), name='Asia/Shanghai')).strftime("%Y-%m-%d_%H_%M_%S") 28 | if self.data_env == 'gpu': 29 | time_str = datetime.utcnow().replace(tzinfo=timezone.utc).strftime("%Y-%m-%d_%H_%M_%S") 30 | file_name_head = 'log_' + self.method + '_' + self.data_name + '_' 31 | self.args.out_file = open(osp.join(self.args.result_dir, file_name_head + time_str + '.txt'), 'w') 32 | self.args.out_file.write(self._print_args() + '\n') 33 | self.args.out_file.flush() 34 | return self.args 35 | 36 | def record(self, log_str): 37 | self.args.out_file.write(log_str + '\n') 38 | self.args.out_file.flush() 39 | return self.args 40 | 41 | def _print_args(self): 42 | s = "==========================================\n" 43 | for arg, content in self.args.__dict__.items(): 44 | s += "{}:{}\n".format(arg, content) 45 | return s 46 | 
-------------------------------------------------------------------------------- /EEG_cross_subject_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | 4 | 5 | class EEG_loader(): 6 | 7 | def __init__(self, test_subj=None, dataset=None): 8 | 9 | test_subj = test_subj 10 | data_folder = './data/' + str(dataset) 11 | 12 | train_x_arr = [] 13 | train_y_arr = [] 14 | 15 | prefix = 's' 16 | 17 | mat = sio.loadmat(data_folder + "/" + prefix + str(test_subj) + ".mat") 18 | x = np.moveaxis(np.array(mat['x']), -1, 0) 19 | y = np.array(mat['y']) 20 | test_x = x 21 | test_y = y 22 | 23 | a = 0 24 | if dataset == 'MI1':# 9 subjects 25 | k = 9 26 | elif dataset == 'MI2':# 14 subjects 27 | k = 14 28 | elif dataset == 'ERP1':# 10 subjects 29 | k = 10 30 | elif dataset == 'ERP2':# 16 subjects 31 | k = 16 32 | 33 | for i in range(a, k): 34 | 35 | mat = sio.loadmat(data_folder + "/" + prefix + str(i) + ".mat") 36 | x = np.moveaxis(np.array(mat['x']), -1, 0) 37 | y = np.array(mat['y']) 38 | 39 | train_x_arr.append(x) 40 | train_y_arr.append(y) 41 | 42 | train_x_array_out = [] 43 | train_y_array_out = [] 44 | for train_x, train_y in zip(train_x_arr, train_y_arr): 45 | 46 | np.random.seed(42) 47 | idx = list(range(len(train_y))) 48 | np.random.shuffle(idx) 49 | train_x = train_x[idx] 50 | train_y = train_y[idx] 51 | 52 | train_x_array_out.append(train_x) 53 | train_y_array_out.append(train_y) 54 | 55 | idx = list(range(len(test_y))) 56 | np.random.shuffle(idx) 57 | test_x = test_x[idx] 58 | test_y = test_y[idx] 59 | 60 | self.train_x = train_x_array_out 61 | self.train_y = train_y_array_out 62 | self.test_x = test_x 63 | self.test_y = test_y 64 | -------------------------------------------------------------------------------- /utils/CsvRecord.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/12 12:44 下午 3 | # 
@Author : wenzhang 4 | # @File : CsvRecord.py 5 | import pandas as pd 6 | import numpy as np 7 | import os 8 | import csv 9 | from datetime import datetime 10 | from datetime import timedelta, timezone 11 | 12 | 13 | class CsvRecord: 14 | def __init__(self, args): 15 | self.data_env = args.data_env 16 | self.data_str = args.data 17 | self.N = args.N 18 | self.file_str = args.file_str 19 | 20 | def init(self): 21 | name_list = ['file', 'data', 'time'] + [str(i + 1) for i in range(self.N)] + ['Avg', 'Std'] 22 | 23 | acc_str_list = ['-' for _ in range(self.N)] 24 | if self.data_env == 'local': 25 | self.time_str = datetime.utcnow().replace(tzinfo=timezone.utc).astimezone( 26 | timezone(timedelta(hours=8), name='Asia/Shanghai')).strftime("%m-%d_%H_%M_%S") 27 | if self.data_env == 'gpu': 28 | self.time_str = datetime.utcnow().replace(tzinfo=timezone.utc).strftime("%m-%d_%H_%M_%S") 29 | output_str_row = np.array([self.file_str, self.data_str, self.time_str] + acc_str_list + ['-', '-']) 30 | output_pd = pd.DataFrame(dict(zip(name_list, output_str_row.T)), index=[0]) 31 | 32 | # 检测是否存在该文件,如果存在则不init 33 | self.save_path = './csv/acc_log_' + self.data_str + '.csv' 34 | if not os.path.exists(self.save_path): 35 | output_pd.to_csv(self.save_path, index=None) 36 | 37 | def record(self, acc_array_raw): 38 | acc_str_list = [str(i) for i in np.round(acc_array_raw, 2)] 39 | mean_acc = np.round(np.mean(acc_array_raw), 2) 40 | std_acc = np.round(np.std(acc_array_raw), 2) 41 | output_str_row = np.array( 42 | [self.file_str, self.data_str, self.time_str] + acc_str_list + [str(mean_acc), str(std_acc)]) 43 | with open(self.save_path, mode='a', newline='', encoding='utf8') as cfa: 44 | csv.writer(cfa).writerow(output_str_row) 45 | 46 | 47 | if __name__ == '__main__': 48 | import argparse 49 | 50 | args = argparse.Namespace(data_env='local', data='MI2-4') 51 | args.N = 9 52 | args.file_str = 'demo_test' 53 | 54 | csv_log = CsvRecord(args) 55 | csv_log.init() 56 | 57 | sub_acc_all = 
np.array([69.097, 29.167, 80.208, 41.319, 38.889, 36.111, 64.583, 74.306, 67.014]) 58 | csv_log.record(sub_acc_all) 59 | -------------------------------------------------------------------------------- /Baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import TensorDataset, DataLoader 4 | 5 | from models.EEGNet import EEGNet 6 | from models.ShallowConvNet import ShallowConvNet, ShallowConvNetReduced 7 | from EEG_cross_subject_loader import EEG_loader 8 | 9 | import random 10 | 11 | 12 | def main( 13 | test_subj=None, 14 | learning_rate=None, 15 | num_iterations=None, 16 | cuda=None, 17 | seed=None, 18 | dataset=None, 19 | model_name=None, 20 | save=False, 21 | ): 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | device = torch.device('cpu') 26 | if cuda: 27 | torch.cuda.manual_seed(seed) 28 | device = torch.device('cuda:2') 29 | print('using cuda...') 30 | 31 | data = EEG_loader(test_subj=test_subj, dataset=dataset) 32 | train_x_arr = data.train_x 33 | train_y_arr = data.train_y 34 | train_x_arr_tmp = [] 35 | train_y_arr_tmp = [] 36 | 37 | for train_x, train_y in zip(train_x_arr, train_y_arr): 38 | train_x_arr_tmp.append(train_x) 39 | train_y_arr_tmp.append(train_y) 40 | 41 | tensor_train_x, tensor_train_y = torch.from_numpy(train_x).unsqueeze_(3).to( 42 | torch.float32), torch.squeeze(torch.from_numpy(train_y), 1).to(torch.long) 43 | 44 | train_dataset = TensorDataset(tensor_train_x, tensor_train_y) 45 | train_loader = DataLoader(train_dataset, batch_size=64) 46 | 47 | del data, train_x, train_y 48 | 49 | if model_name == 'ShallowConvNet': 50 | if dataset == 'MI1': 51 | model = ShallowConvNet(4, 22, 16640) 52 | if dataset == 'MI2': 53 | model = ShallowConvNet(2, 15, 26520) 54 | if dataset == 'ERP1': 55 | model = ShallowConvNetReduced(2, 16, 6760) 56 | if dataset == 'ERP2': 57 | model = ShallowConvNet(2, 56, 17160) 58 | elif 
model_name == 'EEGNet': 59 | if dataset == 'MI1': 60 | model = EEGNet(22, 256, 4) 61 | if dataset == 'MI2': 62 | model = EEGNet(15, 384, 2) 63 | if dataset == 'ERP1': 64 | model = EEGNet(16, 32, 2) 65 | if dataset == 'ERP2': 66 | model = EEGNet(56, 256, 2) 67 | 68 | model.to(device) 69 | 70 | opt = torch.optim.Adam(model.parameters(), lr=learning_rate) 71 | 72 | if dataset == 'ERP1': 73 | class_weight = torch.tensor([1., 4.99], dtype=torch.float32).to(device) 74 | criterion = torch.nn.CrossEntropyLoss(weight=class_weight) 75 | elif dataset == 'ERP2': 76 | class_weight = torch.tensor([1., 2.42], dtype=torch.float32).to(device) 77 | criterion = torch.nn.CrossEntropyLoss(weight=class_weight) 78 | else: 79 | criterion = torch.nn.CrossEntropyLoss() 80 | 81 | # Train the model 82 | for epoch in range(num_iterations): 83 | model.train() 84 | # print('epoch:', epoch + 1) 85 | total_loss = 0 86 | cnt = 0 87 | for i, (x, y) in enumerate(train_loader): 88 | # Forward pass 89 | x = x.to(device) 90 | y = y.to(device) 91 | 92 | outputs = model(x) 93 | loss = criterion(outputs, y) 94 | total_loss += loss 95 | cnt += 1 96 | 97 | # Backward and optimize 98 | opt.zero_grad() 99 | loss.backward() 100 | opt.step() 101 | out_loss = total_loss / cnt 102 | 103 | print('Epoch [{}/{}], Loss: {:.4f}' 104 | .format(epoch + 1, num_iterations, out_loss)) 105 | 106 | if (epoch + 1) % 50 == 0 and epoch != 0 and save: 107 | # Save the model checkpoint 108 | torch.save(model, './runs/' + str(dataset) + '/baseline_' + model_name + dataset + '_seed' + str( 109 | seed) + '_test_subj_' + str(test_subj) + '_epoch' + str(epoch + 1) + '.pt') 110 | 111 | 112 | if __name__ == '__main__': 113 | 114 | lr = 0.001 115 | num_iterations = 100 116 | for model_name in ['EEGNet', 'ShallowConvNet']: 117 | for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']: 118 | if dataset == 'MI1': 119 | subj_num = 9 120 | elif dataset == 'MI2': 121 | subj_num = 14 122 | elif dataset == 'ERP1': 123 | subj_num = 10 124 | elif dataset == 
'ERP2': 125 | subj_num = 16 126 | 127 | for test_subj in range(0, subj_num): 128 | for seed in range(0, 10): 129 | print('Baseline', model_name, dataset) 130 | print('subj', test_subj, 'seed', seed) 131 | main(test_subj=test_subj, 132 | learning_rate=lr, 133 | num_iterations=num_iterations, 134 | cuda=True, 135 | seed=42, 136 | dataset=dataset, 137 | model_name=model_name, 138 | save=True, 139 | ) 140 | -------------------------------------------------------------------------------- /models/ShallowConvNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ShallowConvNet(nn.Module): 6 | def __init__(self, n_classes, input_ch, fc_ch, batch_norm=True, batch_norm_alpha=0.1): 7 | super(ShallowConvNet, self).__init__() 8 | self.batch_norm = batch_norm 9 | self.batch_norm_alpha = batch_norm_alpha 10 | self.n_classes = n_classes 11 | n_ch1 = 40 12 | 13 | if self.batch_norm: 14 | self.layer1 = nn.Sequential( 15 | nn.Conv2d(1, n_ch1, kernel_size=(1, 13), stride=1, padding=(6, 7)), 16 | nn.Conv2d(n_ch1, n_ch1, kernel_size=(input_ch, 1), stride=1, bias=not self.batch_norm), 17 | nn.BatchNorm2d(n_ch1, 18 | momentum=self.batch_norm_alpha, 19 | affine=True, 20 | eps=1e-5)) 21 | 22 | self.fc = nn.Linear(fc_ch, n_classes) 23 | 24 | def forward(self, x): 25 | x = x.permute(0, 3, 1, 2) 26 | x = self.layer1(x) 27 | x = torch.square(x) 28 | x = torch.nn.functional.avg_pool2d(x, (1, 35), (1, 7)) 29 | x = torch.log(x) 30 | x = x.flatten(1) 31 | x = torch.nn.functional.dropout(x) 32 | x = self.fc(x) 33 | return x 34 | 35 | 36 | class ShallowConvNetReduced(nn.Module): 37 | def __init__(self, n_classes, input_ch, fc_ch, batch_norm=True, batch_norm_alpha=0.1): 38 | super(ShallowConvNetReduced, self).__init__() 39 | self.batch_norm = batch_norm 40 | self.batch_norm_alpha = batch_norm_alpha 41 | self.n_classes = n_classes 42 | n_ch1 = 40 43 | 44 | if self.batch_norm: 
45 | self.layer1 = nn.Sequential( 46 | nn.Conv2d(1, n_ch1, kernel_size=(1, 13), stride=1, padding=(6, 7)), 47 | nn.Conv2d(n_ch1, n_ch1, kernel_size=(input_ch, 1), stride=1, bias=not self.batch_norm), 48 | nn.BatchNorm2d(n_ch1, 49 | momentum=self.batch_norm_alpha, 50 | affine=True, 51 | eps=1e-5)) 52 | 53 | self.fc = nn.Linear(fc_ch, n_classes) 54 | 55 | def forward(self, x): 56 | x = x.permute(0, 3, 1, 2) 57 | x = self.layer1(x) 58 | x = torch.square(x) 59 | x = torch.nn.functional.avg_pool2d(x, (1, 10), (1, 2)) 60 | x = torch.log(x) 61 | x = x.flatten(1) 62 | x = torch.nn.functional.dropout(x) 63 | x = self.fc(x) 64 | return x 65 | 66 | 67 | class ShallowConvNetFeatures(nn.Module): 68 | def __init__(self, n_classes, input_ch, fc_ch, batch_norm=True, batch_norm_alpha=0.1): 69 | super(ShallowConvNetFeatures, self).__init__() 70 | self.batch_norm = batch_norm 71 | self.batch_norm_alpha = batch_norm_alpha 72 | self.n_classes = n_classes 73 | n_ch1 = 40 74 | 75 | if self.batch_norm: 76 | self.layer1 = nn.Sequential( 77 | nn.Conv2d(1, n_ch1, kernel_size=(1, 13), stride=1, padding=(6, 7)), 78 | nn.Conv2d(n_ch1, n_ch1, kernel_size=(input_ch, 1), stride=1, bias=not self.batch_norm), 79 | nn.BatchNorm2d(n_ch1, 80 | momentum=self.batch_norm_alpha, 81 | affine=True, 82 | eps=1e-5)) 83 | 84 | def forward(self, x): 85 | x = x.permute(0, 3, 1, 2) 86 | x = self.layer1(x) 87 | x = torch.square(x) 88 | x = torch.nn.functional.avg_pool2d(x, (1, 35), (1, 7)) 89 | x = torch.log(x) 90 | x = x.flatten(1) 91 | x = torch.nn.functional.dropout(x) 92 | 93 | return x 94 | 95 | 96 | class ShallowConvNetFeaturesReduced(nn.Module): 97 | def __init__(self, n_classes, input_ch, fc_ch, batch_norm=True, batch_norm_alpha=0.1): 98 | super(ShallowConvNetFeaturesReduced, self).__init__() 99 | self.batch_norm = batch_norm 100 | self.batch_norm_alpha = batch_norm_alpha 101 | self.n_classes = n_classes 102 | n_ch1 = 40 103 | 104 | if self.batch_norm: 105 | self.layer1 = nn.Sequential( 106 | nn.Conv2d(1, 
n_ch1, kernel_size=(1, 13), stride=1, padding=(6, 7)), 107 | nn.Conv2d(n_ch1, n_ch1, kernel_size=(input_ch, 1), stride=1, bias=not self.batch_norm), 108 | nn.BatchNorm2d(n_ch1, 109 | momentum=self.batch_norm_alpha, 110 | affine=True, 111 | eps=1e-5)) 112 | 113 | def forward(self, x): 114 | x = x.permute(0, 3, 1, 2) 115 | x = self.layer1(x) 116 | x = torch.square(x) 117 | x = torch.nn.functional.avg_pool2d(x, (1, 10), (1, 2)) 118 | x = torch.log(x) 119 | x = x.flatten(1) 120 | x = torch.nn.functional.dropout(x) 121 | 122 | return x 123 | 124 | 125 | class ShallowConvNetClassifier(nn.Module): 126 | def __init__(self, n_classes, input_ch, fc_ch, batch_norm=True, batch_norm_alpha=0.1): 127 | super(ShallowConvNetClassifier, self).__init__() 128 | 129 | self.fc = nn.Linear(fc_ch, n_classes) 130 | 131 | def forward(self, x): 132 | 133 | x = self.fc(x) 134 | return x -------------------------------------------------------------------------------- /models/EEGNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class EEGNet(nn.Module): 6 | def __init__(self, in_chan=0, fc_num=0, out_chann=0): 7 | super(EEGNet, self).__init__() 8 | 9 | # Layer 1 10 | self.conv1 = nn.Conv2d(1, 16, (1, in_chan), padding=0) 11 | self.batchnorm1 = nn.BatchNorm2d(16, False) 12 | 13 | # Layer 2 14 | self.padding1 = nn.ZeroPad2d((16, 17, 0, 1)) 15 | self.conv2 = nn.Conv2d(1, 4, (2, 32)) 16 | self.batchnorm2 = nn.BatchNorm2d(4, False) 17 | self.pooling2 = nn.MaxPool2d((2, 4)) 18 | 19 | # Layer 3 20 | self.padding2 = nn.ZeroPad2d((2, 1, 4, 3)) 21 | self.conv3 = nn.Conv2d(4, 4, (8, 4)) 22 | self.batchnorm3 = nn.BatchNorm2d(4, False) 23 | self.pooling3 = nn.MaxPool2d((2, 4)) 24 | 25 | # FC Layer 26 | self.fc1 = nn.Linear(fc_num, out_chann) 27 | 28 | def forward(self, x): 29 | x = x.permute(0, 3, 2, 1) 30 | # Layer 1 31 | x = F.elu(self.conv1(x)) 32 | x = self.batchnorm1(x) 33 | x = F.dropout(x, 
0.25) 34 | x = x.permute(0, 3, 1, 2) 35 | 36 | # Layer 2 37 | x = self.padding1(x) 38 | x = F.elu(self.conv2(x)) 39 | x = self.batchnorm2(x) 40 | x = F.dropout(x, 0.25) 41 | x = self.pooling2(x) 42 | 43 | # Layer 3 44 | x = self.padding2(x) 45 | x = F.elu(self.conv3(x)) 46 | x = self.batchnorm3(x) 47 | x = F.dropout(x, 0.25) 48 | x = self.pooling3(x) 49 | 50 | # FC Layer 51 | x = x.reshape(x.size()[0], -1) 52 | x = self.fc1(x) 53 | return x 54 | 55 | 56 | class EEGNet_features(nn.Module): 57 | def __init__(self, in_chan=0, fc_num=0, out_chann=0): 58 | super(EEGNet_features, self).__init__() 59 | 60 | # Layer 1 61 | self.conv1 = nn.Conv2d(1, 16, (1, in_chan), padding=0) 62 | self.batchnorm1 = nn.BatchNorm2d(16, False) 63 | 64 | # Layer 2 65 | self.padding1 = nn.ZeroPad2d((16, 17, 0, 1)) 66 | self.conv2 = nn.Conv2d(1, 4, (2, 32)) 67 | self.batchnorm2 = nn.BatchNorm2d(4, False) 68 | self.pooling2 = nn.MaxPool2d((2, 4)) 69 | 70 | # Layer 3 71 | self.padding2 = nn.ZeroPad2d((2, 1, 4, 3)) 72 | self.conv3 = nn.Conv2d(4, 4, (8, 4)) 73 | self.batchnorm3 = nn.BatchNorm2d(4, False) 74 | self.pooling3 = nn.MaxPool2d((2, 4)) 75 | 76 | def forward(self, x): 77 | x = x.permute(0, 3, 2, 1) 78 | # Layer 1 79 | x = F.elu(self.conv1(x)) 80 | x = self.batchnorm1(x) 81 | x = F.dropout(x, 0.25) 82 | x = x.permute(0, 3, 1, 2) 83 | 84 | # Layer 2 85 | x = self.padding1(x) 86 | x = F.elu(self.conv2(x)) 87 | x = self.batchnorm2(x) 88 | x = F.dropout(x, 0.25) 89 | x = self.pooling2(x) 90 | 91 | # Layer 3 92 | x = self.padding2(x) 93 | x = F.elu(self.conv3(x)) 94 | x = self.batchnorm3(x) 95 | x = F.dropout(x, 0.25) 96 | x = self.pooling3(x) 97 | 98 | # FC Layer 99 | x = x.reshape(x.size()[0], -1) 100 | return x 101 | 102 | 103 | class EEGNet_classifier(nn.Module): 104 | def __init__(self, in_chan=0, fc_num=0, out_chann=0): 105 | super(EEGNet_classifier, self).__init__() 106 | 107 | # FC Layer 108 | self.fc1 = nn.Linear(fc_num, out_chann) 109 | 110 | def forward(self, x): 111 | x = self.fc1(x) 
112 | return x 113 | 114 | 115 | class EEGNet_features1(nn.Module): 116 | def __init__(self, in_chan=0, fc_num=0, out_chann=0): 117 | super(EEGNet_features1, self).__init__() 118 | 119 | # Layer 1 120 | self.conv1 = nn.Conv2d(1, 16, (1, in_chan), padding=0) 121 | self.batchnorm1 = nn.BatchNorm2d(16, False) 122 | 123 | def forward(self, x): 124 | x = x.permute(0, 3, 2, 1) 125 | # Layer 1 126 | x = F.elu(self.conv1(x)) 127 | x = self.batchnorm1(x) 128 | x = F.dropout(x, 0.25) 129 | 130 | return x 131 | 132 | 133 | class EEGNet_latter(nn.Module): 134 | def __init__(self, in_chan=0, fc_num=0, out_chann=0): 135 | super(EEGNet_latter, self).__init__() 136 | 137 | # Layer 2 138 | self.padding1 = nn.ZeroPad2d((16, 17, 0, 1)) 139 | self.conv2 = nn.Conv2d(1, 4, (2, 32)) 140 | self.batchnorm2 = nn.BatchNorm2d(4, False) 141 | self.pooling2 = nn.MaxPool2d((2, 4)) 142 | 143 | # Layer 3 144 | self.padding2 = nn.ZeroPad2d((2, 1, 4, 3)) 145 | self.conv3 = nn.Conv2d(4, 4, (8, 4)) 146 | self.batchnorm3 = nn.BatchNorm2d(4, False) 147 | self.pooling3 = nn.MaxPool2d((2, 4)) 148 | 149 | # FC Layer 150 | self.fc1 = nn.Linear(fc_num, out_chann) 151 | 152 | def forward(self, x): 153 | x = x.permute(0, 3, 1, 2) 154 | 155 | # Layer 2 156 | x = self.padding1(x) 157 | x = F.elu(self.conv2(x)) 158 | x = self.batchnorm2(x) 159 | x = F.dropout(x, 0.25) 160 | x = self.pooling2(x) 161 | 162 | # Layer 3 163 | x = self.padding2(x) 164 | x = F.elu(self.conv3(x)) 165 | x = self.batchnorm3(x) 166 | x = F.dropout(x, 0.25) 167 | x = self.pooling3(x) 168 | 169 | # FC Layer 170 | x = x.reshape(x.size()[0], -1) 171 | x = self.fc1(x) 172 | return x -------------------------------------------------------------------------------- /utils/data_augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/25 21:01 3 | # @Author : wenzhang 4 | # @File : data_augment.py 5 | 6 | import numpy as np 7 | from scipy.signal import hilbert 8 | 9 
| 10 | def data_aug(data, labels, size, flag_aug): 11 | # ugments data based on boolean inputs reuse_data, noise_data, neg_data, freq_mod data. 12 | # data: samples * size * n_channels 13 | # size: int(freq * window_size) 14 | # Returns: entire training dataset after data augmentation, and the corresponding labels 15 | 16 | # noise_flag, neg_flag, mult_flag, freq_mod_flag test 75.154 17 | # mult_flag, noise_flag, neg_flag, freq_mod_flag test 76.235 18 | # noise_flag, neg_flag, freq_mod_flag test 76.157 19 | 20 | mult_flag, noise_flag, neg_flag, freq_mod_flag = flag_aug[0], flag_aug[1], flag_aug[2], flag_aug[3] 21 | 22 | n_channels = data.shape[2] 23 | data_out = data # 1 raw features 24 | labels_out = labels 25 | 26 | if mult_flag: # 2 features 27 | mult_data_add, labels_mult = data_mult_f(data, labels, size, n_channels=n_channels) 28 | data_out = np.concatenate([data_out, mult_data_add], axis=0) 29 | labels_out = np.append(labels_out, labels_mult) 30 | if noise_flag: # 1 features 31 | noise_data_add, labels_noise = data_noise_f(data, labels, size, n_channels=n_channels) 32 | data_out = np.concatenate([data_out, noise_data_add], axis=0) 33 | labels_out = np.append(labels_out, labels_noise) 34 | if neg_flag: # 1 features 35 | neg_data_add, labels_neg = data_neg_f(data, labels, size, n_channels=n_channels) 36 | data_out = np.concatenate([data_out, neg_data_add], axis=0) 37 | labels_out = np.append(labels_out, labels_neg) 38 | if freq_mod_flag: # 2 features 39 | freq_data_add, labels_freq = freq_mod_f(data, labels, size, n_channels=n_channels) 40 | data_out = np.concatenate([data_out, freq_data_add], axis=0) 41 | labels_out = np.append(labels_out, labels_freq) 42 | 43 | # 最终输出data格式为 44 | # raw 144, mult_add 144, mult_reduce 144, noise 144, neg 144, freq1 144, freq2 144 45 | return data_out, labels_out 46 | 47 | 48 | def data_noise_f(data, labels, size, n_channels=22): 49 | new_data = [] 50 | new_labels = [] 51 | noise_mod_val = 2 52 | # print("noise mod: 
{}".format(noise_mod_val)) 53 | for i in range(len(labels)): 54 | if labels[i] >= 0: 55 | stddev_t = np.std(data[i]) 56 | rand_t = np.random.rand(data[i].shape[0], data[i].shape[1]) 57 | rand_t = rand_t - 0.5 58 | to_add_t = rand_t * stddev_t / noise_mod_val 59 | data_t = data[i] + to_add_t 60 | new_data.append(data_t) 61 | new_labels.append(labels[i]) 62 | 63 | new_data_ar = np.array(new_data).reshape([-1, size, n_channels]) 64 | new_labels = np.array(new_labels) 65 | 66 | return new_data_ar, new_labels 67 | 68 | 69 | def data_mult_f(data, labels, size, n_channels=22): 70 | new_data = [] 71 | new_labels = [] 72 | mult_mod = 0.05 73 | # print("mult mod: {}".format(mult_mod)) 74 | for i in range(len(labels)): 75 | if labels[i] >= 0: 76 | # print(data[i]) 77 | data_t = data[i] * (1 + mult_mod) 78 | new_data.append(data_t) 79 | new_labels.append(labels[i]) 80 | 81 | for i in range(len(labels)): 82 | if labels[i] >= 0: 83 | data_t = data[i] * (1 - mult_mod) 84 | new_data.append(data_t) 85 | new_labels.append(labels[i]) 86 | 87 | new_data_ar = np.array(new_data).reshape([-1, size, n_channels]) 88 | new_labels = np.array(new_labels) 89 | 90 | return new_data_ar, new_labels 91 | 92 | 93 | def data_neg_f(data, labels, size, n_channels=22): 94 | # Returns: data double the size of the input over time, with new data 95 | # being a reflection along the amplitude 96 | 97 | new_data = [] 98 | new_labels = [] 99 | for i in range(len(labels)): 100 | if labels[i] >= 0: 101 | data_t = -1 * data[i] 102 | data_t = data_t - np.min(data_t) 103 | new_data.append(data_t) 104 | new_labels.append(labels[i]) 105 | 106 | new_data_ar = np.array(new_data).reshape([-1, size, n_channels]) 107 | new_labels = np.array(new_labels) 108 | 109 | return new_data_ar, new_labels 110 | 111 | 112 | def freq_mod_f(data, labels, size, n_channels=22): 113 | new_data = [] 114 | new_labels = [] 115 | # print(data.shape) 116 | freq_mod = 0.2 117 | # print("freq mod: {}".format(freq_mod)) 118 | for i in 
range(len(labels)): 119 | if labels[i] >= 0: 120 | low_shift = freq_shift(data[i], -freq_mod, num_channels=n_channels) 121 | new_data.append(low_shift) 122 | new_labels.append(labels[i]) 123 | 124 | for i in range(len(labels)): 125 | if labels[i] >= 0: 126 | high_shift = freq_shift(data[i], freq_mod, num_channels=n_channels) 127 | new_data.append(high_shift) 128 | new_labels.append(labels[i]) 129 | 130 | new_data_ar = np.array(new_data).reshape([-1, size, n_channels]) 131 | new_labels = np.array(new_labels) 132 | 133 | return new_data_ar, new_labels 134 | 135 | 136 | def freq_shift(x, f_shift, dt=1 / 250, num_channels=22): 137 | shifted_sig = np.zeros((x.shape)) 138 | len_x = len(x) 139 | padding_len = 2 ** nextpow2(len_x) 140 | padding = np.zeros((padding_len - len_x, num_channels)) 141 | with_padding = np.vstack((x, padding)) 142 | hilb_T = hilbert(with_padding, axis=0) 143 | t = np.arange(0, padding_len) 144 | shift_func = np.exp(2j * np.pi * f_shift * dt * t) 145 | for i in range(num_channels): 146 | shifted_sig[:, i] = (hilb_T[:, i] * shift_func)[:len_x].real 147 | 148 | return shifted_sig 149 | 150 | 151 | def nextpow2(x): 152 | return int(np.ceil(np.log2(np.abs(x)))) 153 | -------------------------------------------------------------------------------- /models/EEGNetv4.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EEGNet(nn.Module): 7 | """ 8 | :param 9 | """ 10 | def __init__(self, 11 | Chans=None, 12 | Samples=None, 13 | n_classes=None, 14 | kernLenght=64, 15 | F1=4, 16 | D=2, 17 | F2=8, 18 | dropoutRate=0.25, # cross_subject: 0.25; within_subject: 0.5 19 | norm_rate=0.5): 20 | super(EEGNet, self).__init__() 21 | 22 | self.n_classes = n_classes 23 | self.Chans = Chans 24 | self.Samples = Samples 25 | self.kernLenght = kernLenght 26 | self.F1 = F1 27 | self.D = D 28 | self.F2 = F2 29 | self.dropoutRate = dropoutRate 30 | self.norm_rate 
= norm_rate

        self.block1 = nn.Sequential(
            nn.ZeroPad2d((self.kernLenght // 2 - 1,
                          self.kernLenght - self.kernLenght // 2, 0,
                          0)),  # left, right, up, bottom
            nn.Conv2d(in_channels=1,
                      out_channels=self.F1,
                      kernel_size=(1, self.kernLenght),
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F1),
            # DepthwiseConv2d
            nn.Conv2d(in_channels=self.F1,
                      out_channels=self.F1 * self.D,
                      kernel_size=(self.Chans, 1),
                      groups=self.F1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F1 * self.D),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(p=self.dropoutRate))

        self.block2 = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            # SeparableConv2d
            nn.Conv2d(in_channels=self.F1 * self.D,
                      out_channels=self.F1 * self.D,
                      kernel_size=(1, 16),
                      stride=1,
                      groups=self.F1 * self.D,
                      bias=False),
            nn.Conv2d(in_channels=self.F1 * self.D,
                      out_channels=self.F2,
                      kernel_size=(1, 1),
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.dropoutRate))
        # the two AvgPool layers shrink the time axis by 4*8, fixing the fc width
        self.classifier_block = nn.Sequential(
            nn.Linear(in_features=self.F2 * (self.Samples // (4 * 8)),
                      out_features=self.n_classes,
                      bias=True))

    def forward(self, x):
        # input arrives channels-last; convert to NCHW
        x = x.permute(0, 3, 1, 2)
        output = self.block1(x)
        output = self.block2(output)
        output = output.reshape(output.size(0), -1)
        output = self.classifier_block(output)
        return output


class EEGNet_features(nn.Module):
    """Feature-extractor half of EEGNet-v4 (block1 + block2, no classifier)."""
    def __init__(self,
                 Chans=None,
                 Samples=None,
                 n_classes=None,
                 kernLenght=64,
                 F1=4,
                 D=2,
                 F2=8,
                 dropoutRate=0.25,  # cross_subject: 0.25; within_subject: 0.5
                 norm_rate=0.5):
        super(EEGNet_features, self).__init__()

        self.n_classes = n_classes
        self.Chans = Chans
        self.Samples = Samples
        self.kernLenght = kernLenght
        self.F1 = F1
        self.D = D
        self.F2 = F2
        self.dropoutRate = dropoutRate
        self.norm_rate = norm_rate

        self.block1 = nn.Sequential(
            nn.ZeroPad2d((self.kernLenght // 2 - 1,
                          self.kernLenght - self.kernLenght // 2, 0,
                          0)),  # left, right, up, bottom
            nn.Conv2d(in_channels=1,
                      out_channels=self.F1,
                      kernel_size=(1, self.kernLenght),
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F1),
            # DepthwiseConv2d
            nn.Conv2d(in_channels=self.F1,
                      out_channels=self.F1 * self.D,
                      kernel_size=(self.Chans, 1),
                      groups=self.F1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F1 * self.D),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(p=self.dropoutRate))

        self.block2 = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            # SeparableConv2d
            nn.Conv2d(in_channels=self.F1 * self.D,
                      out_channels=self.F1 * self.D,
                      kernel_size=(1, 16),
                      stride=1,
                      groups=self.F1 * self.D,
                      bias=False),
            nn.Conv2d(in_channels=self.F1 * self.D,
                      out_channels=self.F2,
                      kernel_size=(1, 1),
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(num_features=self.F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.dropoutRate))

    def forward(self, x):
        # input arrives channels-last; convert to NCHW
        x = x.permute(0, 3, 1, 2)
        output = self.block1(x)
        output = self.block2(output)
        output = output.reshape(output.size(0), -1)

        return output

class EEGNet_classifier(nn.Module):
    """Linear head matching EEGNet_features (same constructor signature)."""
    def __init__(self,
                 Chans=None,
                 Samples=None,
                 n_classes=None,
                 kernLenght=64,
                 F1=4,
                 D=2,
                 F2=8,
                 dropoutRate=0.25,  # cross_subject: 0.25; within_subject: 0.5
                 norm_rate=0.5):
        super(EEGNet_classifier, self).__init__()

        self.n_classes = n_classes
        self.Chans = Chans
        
self.Samples = Samples 178 | self.kernLenght = kernLenght 179 | self.F1 = F1 180 | self.D = D 181 | self.F2 = F2 182 | self.dropoutRate = dropoutRate 183 | self.norm_rate = norm_rate 184 | 185 | self.classifier_block = nn.Sequential( 186 | nn.Linear(in_features=self.F2 * (self.Samples // (4 * 8)), 187 | out_features=self.n_classes, 188 | bias=True)) 189 | 190 | def forward(self, x): 191 | output = self.classifier_block(x) 192 | return output -------------------------------------------------------------------------------- /MAML.py: -------------------------------------------------------------------------------- 1 | import learn2learn as l2l 2 | import numpy as np 3 | import torch 4 | 5 | from learn2learn.data.transforms import NWays, KShots, LoadData 6 | 7 | from models.EEGNet import EEGNet 8 | from models.ShallowConvNet import ShallowConvNet, ShallowConvNetReduced 9 | from EEG_cross_subject_loader import EEG_loader 10 | 11 | import random 12 | 13 | 14 | def main( 15 | test_subj=None, 16 | ways=None, 17 | shots=None, 18 | meta_lr=None, 19 | fast_lr=None, 20 | meta_batch_size=None, 21 | adaptation_steps=None, 22 | num_iterations=None, 23 | cuda=None, 24 | seed=None, 25 | model_name=None, 26 | dataset=None, 27 | se=None, 28 | ): 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | device = torch.device('cpu') 33 | if cuda: 34 | torch.cuda.manual_seed(seed) 35 | device = torch.device('cuda:4') 36 | print('using cuda...') 37 | 38 | data = EEG_loader(test_subj=test_subj, dataset=dataset) 39 | train_x_arr = data.train_x 40 | train_y_arr = data.train_y 41 | train_x_arr_tmp = [] 42 | train_y_arr_tmp = [] 43 | for train_x, train_y in zip(train_x_arr, train_y_arr): 44 | train_x_arr_tmp.append(train_x) 45 | train_y_arr_tmp.append(train_y) 46 | train_x_arr_tmp = np.concatenate(train_x_arr_tmp, axis=0) 47 | train_y_arr_tmp = np.concatenate(train_y_arr_tmp, axis=0) 48 | 49 | tensor_train_x, tensor_train_y = 
import random

import learn2learn as l2l
import numpy as np
import torch
from learn2learn.data.transforms import NWays, KShots, LoadData

from models.EEGNet import EEGNet
from models.ShallowConvNet import ShallowConvNet, ShallowConvNetReduced
from EEG_cross_subject_loader import EEG_loader


def main(
        test_subj=None,
        ways=None,
        shots=None,
        meta_lr=None,
        fast_lr=None,
        meta_batch_size=None,
        adaptation_steps=None,
        num_iterations=None,
        cuda=None,
        seed=None,
        model_name=None,
        dataset=None,
        se=None,
):
    """Meta-train a first-order MAML model on all source subjects.

    The target subject ``test_subj`` is excluded by ``EEG_loader``; source
    trials are pooled into N-way / 2*K-shot tasks (half adaptation, half
    evaluation, split in :func:`fast_adapt`). Checkpoints go to
    ``./runs/<dataset>/`` every 50 iterations. ``se`` only tags the
    checkpoint filename; ``seed`` seeds the RNGs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda:
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda:4')
        print('using cuda...')

    # Pool every source subject's trials into one flat dataset.
    data = EEG_loader(test_subj=test_subj, dataset=dataset)
    train_x = np.concatenate(list(data.train_x), axis=0)
    train_y = np.concatenate(list(data.train_y), axis=0)

    # (trials, chans, samples) -> (trials, chans, samples, 1); labels (N,1) -> (N,)
    tensor_train_x = torch.from_numpy(train_x).unsqueeze_(3).to(torch.float32)
    tensor_train_y = torch.squeeze(torch.from_numpy(train_y), 1).to(torch.long)
    train_torch_dataset = torch.utils.data.TensorDataset(tensor_train_x, tensor_train_y)
    train_dataset = l2l.data.MetaDataset(train_torch_dataset)
    # Each task carries 2 * shots examples per class: half for adaptation,
    # half for evaluation.
    train_task = l2l.data.TaskDataset(train_dataset,
                                      task_transforms=[
                                          NWays(train_dataset, n=ways),
                                          KShots(train_dataset, k=2 * shots),
                                          LoadData(train_dataset),
                                      ],
                                      num_tasks=meta_batch_size)
    del train_x, train_y, train_dataset, train_torch_dataset, tensor_train_x, tensor_train_y

    # Backbone selection; fail fast on unknown combinations (the original
    # silently fell through to an UnboundLocalError on `model`).
    if model_name == 'ShallowConvNet':
        if dataset == 'MI1':
            model = ShallowConvNet(4, 22, 16640)
        elif dataset == 'MI2':
            model = ShallowConvNet(2, 15, 26520)
        elif dataset == 'ERP1':
            model = ShallowConvNetReduced(2, 16, 6760)
        elif dataset == 'ERP2':
            model = ShallowConvNet(2, 56, 17160)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    elif model_name == 'EEGNet':
        if dataset == 'MI1':
            model = EEGNet(22, 256, 4)
        elif dataset == 'MI2':
            model = EEGNet(15, 384, 2)
        elif dataset == 'ERP1':
            model = EEGNet(16, 32, 2)
        elif dataset == 'ERP2':
            model = EEGNet(56, 256, 2)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    else:
        raise ValueError('unknown model_name: ' + str(model_name))

    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=True, allow_nograd=True)
    opt = torch.optim.Adam(maml.parameters(), lr=meta_lr)
    loss = torch.nn.CrossEntropyLoss()

    # Checkpoint tag; loop-invariant, so built once (the original rebuilt
    # it on every iteration).
    s = dataset + '_test_subj_' + str(test_subj) + '_shots_' + str(shots) + '_meta_lr_' + str(
        meta_lr) + '_fast_lr_' + \
        str(fast_lr) + '_meta_batch_size_' + str(meta_batch_size) + '_adaptation_steps_' + str(
        adaptation_steps) + str(model_name)

    print('start training...')
    for iteration in range(1, num_iterations + 1):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0

        for batch in train_task:
            learner = maml.clone()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / (meta_batch_size))
        print('Meta Train Accuracy', meta_train_accuracy / (meta_batch_size))

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            if p.grad is None:
                continue
            p.grad.data.mul_(1.0 / (meta_batch_size))
        opt.step()

        if iteration % 50 == 0:
            print('saving model...')
            torch.save(model,
                       './runs/' + str(dataset) + '/maml_' + s + '_num_iterations_' + str(
                           iteration) + 'seed' + str(se) + '.pth')


def accuracy(predictions, targets):
    """Fraction of correct argmax predictions, as a 0-dim float tensor."""
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt ``learner`` on half of ``batch``, evaluate on the other half.

    Even-indexed examples of the task form the adaptation set, odd-indexed
    ones the evaluation set. Returns (evaluation loss, evaluation accuracy).
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets (even vs. odd positions).
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy


if __name__ == '__main__':

    meta_lr = 0.001
    fast_lr = 0.001
    shots = 10
    for model_name in ['EEGNet', 'ShallowConvNet']:
        for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']:
            # tasks per meta-batch = total source trials / task size (2*ways*shots)
            if dataset == 'MI1':
                subj_num = 9
                meta_batch_size = 576 * 8 // (2 * 4 * shots)
            elif dataset == 'MI2':
                subj_num = 14
                meta_batch_size = 100 * 13 // (2 * 2 * shots)
            elif dataset == 'ERP1':
                subj_num = 10
                meta_batch_size = 575 * 9 // (2 * 2 * shots)
            elif dataset == 'ERP2':
                subj_num = 16
                meta_batch_size = 340 * 15 // (2 * 2 * shots)

            ways = 4 if dataset == 'MI1' else 2

            for test_subj in range(0, subj_num):
                for seed in range(0, 10):
                    print('MAML', dataset, model_name)
                    print('subj', test_subj, 'seed', seed)
                    main(test_subj=test_subj,
                         ways=ways,
                         shots=shots,
                         meta_lr=meta_lr,
                         fast_lr=fast_lr,
                         meta_batch_size=meta_batch_size,
                         adaptation_steps=1,
                         num_iterations=200,
                         cuda=True,
                         seed=42,
                         model_name=model_name,
                         dataset=dataset,
                         se=seed,
                         )
import random

import learn2learn as l2l
import numpy as np
import torch

from models.EEGNet import EEGNet_features, EEGNet_classifier
from models.ShallowConvNet import ShallowConvNetFeatures, ShallowConvNetClassifier, ShallowConvNetFeaturesReduced
from EEG_cross_subject_loader import EEG_loader


def main(
        test_subj=None,
        ways=None,
        shots=None,
        meta_lr=None,
        fast_lr=None,
        meta_batch_size=None,
        adaptation_steps=None,
        num_iterations=None,
        cuda=None,
        seed=None,
        model_name=None,
        dataset=None,
        se=None,
):
    """Meta-train an ANIL-style model: the feature extractor is wrapped in
    MAML and adapted per task, while the linear head is shared and jointly
    meta-optimized. Source subjects are paired (0,1), (2,3), ...: adapt on
    the even subject, evaluate on the odd one. Checkpoints are written to
    ``./runs/<dataset>/`` every 50 iterations.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda:
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda:2')
        print('using cuda...')

    data = EEG_loader(test_subj=test_subj, dataset=dataset)

    # One full-batch DataLoader per source subject so that consecutive
    # subjects can be paired as (adaptation, evaluation) domains below.
    l2l_train_tasks_arr = []
    for subj_x, subj_y in zip(data.train_x, data.train_y):
        tensor_x = torch.from_numpy(subj_x).unsqueeze_(3).to(torch.float32)
        tensor_y = torch.squeeze(torch.from_numpy(subj_y), 1).to(torch.long)
        subj_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
        l2l_train_tasks_arr.append(
            torch.utils.data.DataLoader(subj_dataset, batch_size=len(subj_dataset)))
    del data, subj_dataset, tensor_x, tensor_y

    # Backbone selection; fail fast on unknown combinations (the original
    # silently fell through to an UnboundLocalError on `model1`/`model2`).
    if model_name == 'ShallowConvNet':
        if dataset == 'MI1':
            model1 = ShallowConvNetFeatures(4, 22, 16640)
            model2 = ShallowConvNetClassifier(4, 22, 16640)
        elif dataset == 'MI2':
            model1 = ShallowConvNetFeatures(2, 15, 26520)
            model2 = ShallowConvNetClassifier(2, 15, 26520)
        elif dataset == 'ERP1':
            model1 = ShallowConvNetFeaturesReduced(2, 16, 6760)
            model2 = ShallowConvNetClassifier(2, 16, 6760)
        elif dataset == 'ERP2':
            model1 = ShallowConvNetFeatures(2, 56, 17160)
            model2 = ShallowConvNetClassifier(2, 56, 17160)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    elif model_name == 'EEGNet':
        if dataset == 'MI1':
            model1 = EEGNet_features(22, 256, 4)
            model2 = EEGNet_classifier(22, 256, 4)
        elif dataset == 'MI2':
            model1 = EEGNet_features(15, 384, 2)
            model2 = EEGNet_classifier(15, 384, 2)
        elif dataset == 'ERP1':
            model1 = EEGNet_features(16, 32, 2)
            model2 = EEGNet_classifier(16, 32, 2)
        elif dataset == 'ERP2':
            model1 = EEGNet_features(56, 256, 2)
            model2 = EEGNet_classifier(56, 256, 2)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    else:
        raise ValueError('unknown model_name: ' + str(model_name))

    head = model2
    head.to(device)
    features = l2l.algorithms.MAML(model1, lr=fast_lr, first_order=True, allow_nograd=True)
    features.to(device)

    all_parameters = list(features.parameters()) + list(head.parameters())
    opt = torch.optim.Adam(all_parameters, lr=meta_lr)
    loss = torch.nn.CrossEntropyLoss()

    # Checkpoint tag; loop-invariant. The original assigned it *after* its
    # use site inside the loop, which only worked because the save branch
    # is first reached at iteration 50.
    s = dataset + '_test_subj_' + str(test_subj) + '_shots_' + str(shots) + '_meta_lr_' + str(
        meta_lr) + '_fast_lr_' + \
        str(fast_lr) + '_meta_batch_size_' + str(meta_batch_size) + '_adaptation_steps_' + str(
        adaptation_steps) + str(model_name)

    print('start training...')
    for iteration in range(1, num_iterations + 1):
        opt.zero_grad()

        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        cnt = 0

        for tasks_subj_ind in range(0, len(l2l_train_tasks_arr) - 1, 2):
            for train_task, val_task in zip(l2l_train_tasks_arr[tasks_subj_ind],
                                            l2l_train_tasks_arr[tasks_subj_ind + 1]):
                learner = features.clone()
                train_error, train_accuracy = fast_adapt(train_task,
                                                         val_task,
                                                         learner,
                                                         head,
                                                         loss,
                                                         adaptation_steps,
                                                         shots,
                                                         ways,
                                                         device)

                train_error.backward()
                meta_train_error += train_error.item()
                meta_train_accuracy += train_accuracy.item()
                cnt += 1

        print('Iteration', iteration)
        if iteration % 50 == 0:
            print('saving model...')
            torch.save(features,
                       './runs/' + str(dataset) + '/anil_model1_' + s + '_num_iterations_' + str(
                           iteration) + 'seed' + str(se) + '.pt')
            torch.save(head,
                       './runs/' + str(dataset) + '/anil_model2_' + s + '_num_iterations_' + str(
                           iteration) + 'seed' + str(se) + '.pt')

        print('Meta Train Error', meta_train_error / cnt)
        print('Meta Train Accuracy', meta_train_accuracy / cnt)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            if p.grad is None:
                continue
            p.grad.data.mul_(1.0 / cnt)
        opt.step()


def accuracy(predictions, targets):
    """Fraction of correct argmax predictions, as a 0-dim float tensor."""
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(adaptation_batch, evaluation_batch, features, head, loss, adaptation_steps, shots, ways, device):
    """Adapt cloned features on one subject's batch, evaluate on another's.

    NOTE(review): ``x`` is computed once from the pre-adaptation features;
    with adaptation_steps > 1 later steps would reuse stale features. With
    the default of one step this is equivalent to recomputing.
    """
    adaptation_data, adaptation_labels = adaptation_batch
    evaluation_data, evaluation_labels = evaluation_batch
    adaptation_data, adaptation_labels = adaptation_data.to(device), adaptation_labels.to(device)
    evaluation_data, evaluation_labels = evaluation_data.to(device), evaluation_labels.to(device)

    x = features(adaptation_data)

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(head(x), adaptation_labels)
        features.adapt(train_error)

    # Evaluate the adapted model
    predictions = head(features(evaluation_data))
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy


if __name__ == '__main__':

    meta_lr = 0.001
    fast_lr = 0.001
    for model_name in ['EEGNet', 'ShallowConvNet']:
        for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']:
            if dataset == 'MI1':
                subj_num = 9
            elif dataset == 'MI2':
                subj_num = 14
            elif dataset == 'ERP1':
                subj_num = 10
            elif dataset == 'ERP2':
                subj_num = 16

            ways = 4 if dataset == 'MI1' else 2

            for test_subj in range(0, subj_num):
                for seed in range(0, 10):
                    print('ANIL', dataset, model_name)
                    print('subj', test_subj, 'seed', seed)
                    main(test_subj=test_subj,
                         ways=ways,
                         shots=1,
                         meta_lr=meta_lr,
                         fast_lr=fast_lr,
                         meta_batch_size=1,
                         adaptation_steps=1,
                         num_iterations=200,
                         cuda=True,
                         seed=42,
                         model_name=model_name,
                         dataset=dataset,
                         se=seed,
                         )

# ====================================================================== #
# utils/utils.py (module head)
# -*- coding: utf-8 -*-
# @Time    : 2021/12/20 1:06 PM
# @Author  : wenzhang
# @File    : utils.py
import os.path as osp
import os
import numpy as np
import random
import torch as tr
import torch.nn as nn
import torch.utils.data as Data
# Deduplicated: roc_auc_score was imported twice in the original.
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score


def op_copy(optimizer):
    """Remember each param group's initial LR under 'lr0' for the schedulers."""
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer


def fix_random_seed(SEED):
    """Seed torch (CPU and CUDA), numpy and random for reproducibility."""
    tr.manual_seed(SEED)
    tr.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)


def create_folder(dir_name, data_env, win_root):
    """Create the result folder, with per-environment fallbacks if 'mkdir -p' failed."""
    if not osp.exists(dir_name):
        os.system('mkdir -p ' + dir_name)
    if not osp.exists(dir_name):
        if data_env == 'gpu':
            os.mkdir(dir_name)
        elif data_env == 'local':
            os.makedirs(win_root + dir_name)
def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Inverse power-decay LR schedule relative to each group's base LR.

    Requires :func:`op_copy` to have stored the base LR under 'lr0'.
    Also (re)sets weight_decay / momentum / nesterov on every call.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer


def lr_scheduler_full(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    """Same schedule as :func:`lr_scheduler`, but relative to an explicit init_lr."""
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer


def _collect_outputs(loader, netF, netC, device):
    """Concatenate netC(netF(x)) outputs (on CPU) and float labels over a loader.

    Py3 fix: the original drove the loader with ``iter_test.next()``, which
    was removed in Python 3; iterate the loader directly instead.
    """
    all_output, all_label = None, None
    with tr.no_grad():
        for data in loader:
            inputs = data[0].to(device)
            labels = data[1].float()
            outputs = netC(netF(inputs))
            if all_output is None:
                all_output = outputs.float().cpu()
                all_label = labels
            else:
                all_output = tr.cat((all_output, outputs.float().cpu()), 0)
                all_label = tr.cat((all_label, labels), 0)
    return all_output, all_label


def cal_auc_roc(loader, netF, netC, device):
    """AUC (in percent) of hard argmax predictions over `loader` (binary labels)."""
    all_output, all_label = _collect_outputs(loader, netF, netC, device)
    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = tr.max(all_output, 1)
    print('using roc_auc')
    score = roc_auc_score(all_label, tr.squeeze(predict).float())
    return score * 100


def cal_acc(loader, netF, netC, device):
    """Classification accuracy (in percent) over `loader`."""
    all_output, all_label = _collect_outputs(loader, netF, netC, device)
    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = tr.max(all_output, 1)
    accuracy = tr.sum(tr.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    return accuracy * 100


def cal_acc_comb(loader, model, flag=True, fc=None, device=None):
    """Accuracy / sensitivity / specificity / AUC (percent) of a combined model.

    flag=True: `model` returns (features, logits); otherwise `fc` maps
    features to logits when given, or `model` returns logits directly.
    AUC is NaN when it cannot be computed (single-class labels).
    """
    print('here1')
    start_test = True
    with tr.no_grad():
        for data in loader:  # Py3 fix: was iter_test.next()
            inputs = data[0].to(device)
            labels = data[1]
            if flag:
                _, outputs = model(inputs)
            elif fc is not None:
                feas, outputs = model(inputs)
                outputs = fc(feas)
            else:
                outputs = model(inputs)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = tr.cat((all_output, outputs.float().cpu()), 0)
                all_label = tr.cat((all_label, labels.float()), 0)
    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = tr.max(all_output, 1)
    predict = predict.float().numpy()
    all_label = all_label.numpy()
    auc = float('nan')  # fix: `auc` was unbound (NameError) when roc_auc_score raised
    try:
        auc = roc_auc_score(all_label, predict)
    except ValueError:
        pass
    tn, fp, fn, tp = confusion_matrix(all_label, predict).ravel()
    acc = (tp + tn) / (tp + tn + fp + fn)
    sen = tp / (tp + fn)
    spec = tn / (tn + fp)

    return acc * 100, sen * 100, spec * 100, auc * 100


def cal_acc_multi(loader, netF_list, netC_list, args, weight_epoch=None, netG_list=None, device=None):
    """Domain-weighted multi-source ensemble accuracy / sen / spec / AUC (percent)."""
    print('here2')
    num_src = len(netF_list)
    for i in range(len(netF_list)):
        netF_list[i].eval()

    if args.use_weight:
        if args.method == 'msdt':
            domain_weight = weight_epoch.detach()
        else:
            domain_weight = tr.Tensor([1 / num_src] * num_src).reshape([1, num_src, 1]).to(device)
    else:
        # fix: the original left domain_weight unbound when use_weight is
        # False (NameError below); default to uniform weights as in
        # func_utils.update_decision.
        domain_weight = tr.Tensor([1 / num_src] * num_src).reshape([1, num_src, 1]).to(device)

    start_test = True
    with tr.no_grad():
        for data in loader:  # Py3 fix: was iter_test.next()
            inputs, labels = data[0].to(device), data[1]

            if args.use_weight and args.method == 'decision':
                # Per-sample domain weights from the gating networks.
                weights_all = tr.ones(inputs.shape[0], len(args.src))
                tmp_output = tr.zeros(len(args.src), inputs.shape[0], args.class_num)
                for i in range(len(args.src)):
                    tmp_output[i] = netC_list[i](netF_list[i](inputs))
                    weights_all[:, i] = netG_list[i](tmp_output[i]).squeeze()
                z = tr.sum(weights_all, dim=1) + 1e-16
                weights_all = tr.transpose(tr.transpose(weights_all, 0, 1) / z, 0, 1)
                weights_domain = tr.sum(weights_all, dim=0) / tr.sum(weights_all)
                domain_weight = weights_domain.reshape([1, num_src, 1]).to(device)

            outputs_all = tr.cat([netC_list[i](netF_list[i](inputs)).unsqueeze(1) for i in range(num_src)], 1).to(device)
            preds = tr.softmax(outputs_all, dim=2)
            outputs_all_w = (preds * domain_weight).sum(dim=1).to(device)

            if start_test:
                all_output = outputs_all_w.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = tr.cat((all_output, outputs_all_w.float().cpu()), 0)
                all_label = tr.cat((all_label, labels.float()), 0)
    _, predict = tr.max(all_output, 1)
    predict = predict.float().numpy()
    all_label = all_label.numpy()
    auc = float('nan')  # fix: `auc` was unbound (NameError) when roc_auc_score raised
    try:
        auc = roc_auc_score(all_label, predict)
    except ValueError:
        pass
    tn, fp, fn, tp = confusion_matrix(all_label, predict).ravel()
    acc = (tp + tn) / (tp + tn + fp + fn)
    sen = tp / (tp + fn)
    spec = tn / (tn + fp)

    for i in range(len(netF_list)):
        netF_list[i].train()

    return acc * 100, sen * 100, spec * 100, auc * 100


def data_loader(Xs=None, Ys=None, Xt=None, Yt=None, args=None):
    """Build the source/target DataLoader dict used by the DA methods.

    Fix: the original compared tensors with ``!= None``; use identity
    checks (``is not None``), which is the only well-defined test here.
    """
    dset_loaders = {}
    train_bs = args.batch_size

    if Xs is not None:
        # Note: random shuffling can inflate training scores; it does not
        # affect testing. (translated from the original Chinese comment)
        src_idx = np.arange(len(Ys.numpy()))
        if args.validation == 'random':  # for SEED
            num_train = int(0.9 * len(src_idx))
            tr.manual_seed(args.SEED)
            id_train, id_val = tr.utils.data.random_split(src_idx, [num_train, len(src_idx) - num_train])
        if args.validation == 'last':  # for MI
            num_all = args.trial
            num_train = int(0.9 * num_all)
            id_train = np.array(src_idx).reshape(-1, num_all)[:, :num_train].reshape(1, -1).flatten()
            id_val = np.array(src_idx).reshape(-1, num_all)[:, num_train:].reshape(1, -1).flatten()

        data_src = Data.TensorDataset(Xs, Ys)
        source_tr = Data.TensorDataset(Xs[id_train, :], Ys[id_train])
        source_te = Data.TensorDataset(Xs[id_val, :], Ys[id_val])

        # for DNN
        dset_loaders["source_tr"] = Data.DataLoader(source_tr, batch_size=train_bs, shuffle=True, drop_last=True)
        dset_loaders["source_te"] = Data.DataLoader(source_te, batch_size=train_bs, shuffle=False, drop_last=False)
        # for DAN/DANN/CDAN/MCC
        dset_loaders["source"] = Data.DataLoader(data_src, batch_size=train_bs, shuffle=True, drop_last=True)
        # for generating feature
        dset_loaders["Source"] = Data.DataLoader(data_src, batch_size=train_bs * 3, shuffle=False, drop_last=False)

    if Xt is not None:
        data_tar = Data.TensorDataset(Xt, Yt)
        dset_loaders["target"] = Data.DataLoader(data_tar, batch_size=train_bs, shuffle=True, drop_last=True)
        dset_loaders["Target"] = Data.DataLoader(data_tar, batch_size=train_bs * 3, shuffle=False, drop_last=False)

    return dset_loaders
import random

import learn2learn as l2l
import numpy as np
import torch

from models.EEGNet import EEGNet_features1, EEGNet_latter
from models.ShallowConvNet import ShallowConvNetFeatures, ShallowConvNetFeaturesReduced, ShallowConvNetClassifier
from EEG_cross_subject_loader import EEG_loader


def main(
        test_subj=None,
        ways=None,
        shots=None,
        meta_lr=None,
        fast_lr=None,
        meta_batch_size=None,
        adaptation_steps=None,
        num_iterations=None,
        cuda=None,
        seed=None,
        model_name=None,
        dataset=None,
        se=None,
):
    """Meta-train MDMAML: a shared front end (`former`) plus a MAML-adapted
    back end (`latter`), with per-iteration random source->pseudo-target
    domain pairing and a negative-transfer filter — only pairs whose
    adapted loss beats the unadapted reference loss contribute gradients.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda:
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda:1')
        print('using cuda...')

    data = EEG_loader(test_subj=test_subj, dataset=dataset)
    # One full-batch DataLoader per source subject.
    l2l_train_tasks_arr = []
    for subj_x, subj_y in zip(data.train_x, data.train_y):
        tensor_x = torch.from_numpy(subj_x).unsqueeze_(3).to(torch.float32)
        tensor_y = torch.squeeze(torch.from_numpy(subj_y), 1).to(torch.long)
        subj_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
        l2l_train_tasks_arr.append(
            torch.utils.data.DataLoader(subj_dataset, batch_size=len(subj_dataset)))
    del data, subj_dataset, tensor_x, tensor_y

    # Backbone selection; fail fast on unknown combinations (the original
    # silently fell through to an UnboundLocalError on `model1`/`model2`).
    if model_name == 'ShallowConvNet':
        if dataset == 'MI1':
            model1 = ShallowConvNetFeatures(4, 22, 16640)
            model2 = ShallowConvNetClassifier(4, 22, 16640)
        elif dataset == 'MI2':
            model1 = ShallowConvNetFeatures(2, 15, 26520)
            model2 = ShallowConvNetClassifier(2, 15, 26520)
        elif dataset == 'ERP1':
            model1 = ShallowConvNetFeaturesReduced(2, 16, 6760)
            model2 = ShallowConvNetClassifier(2, 16, 6760)
        elif dataset == 'ERP2':
            model1 = ShallowConvNetFeatures(2, 56, 17160)
            model2 = ShallowConvNetClassifier(2, 56, 17160)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    elif model_name == 'EEGNet':
        if dataset == 'MI1':
            model1 = EEGNet_features1(22, 256, 4)
            model2 = EEGNet_latter(22, 256, 4)
        elif dataset == 'MI2':
            model1 = EEGNet_features1(15, 384, 2)
            model2 = EEGNet_latter(15, 384, 2)
        elif dataset == 'ERP1':
            model1 = EEGNet_features1(16, 32, 2)
            model2 = EEGNet_latter(16, 32, 2)
        elif dataset == 'ERP2':
            model1 = EEGNet_features1(56, 256, 2)
            model2 = EEGNet_latter(56, 256, 2)
        else:
            raise ValueError('unknown dataset: ' + str(dataset))
    else:
        raise ValueError('unknown model_name: ' + str(model_name))

    # Class-imbalance weights for the ERP datasets.
    if dataset == 'ERP1':
        class_weight = torch.tensor([1., 4.99], dtype=torch.float32).to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weight)
    elif dataset == 'ERP2':
        class_weight = torch.tensor([1., 2.42], dtype=torch.float32).to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weight)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    former = model1
    former.to(device)
    latter = l2l.algorithms.MAML(model2, lr=fast_lr, first_order=True, allow_nograd=True)
    latter.to(device)

    all_parameters = list(former.parameters()) + list(latter.parameters())
    opt = torch.optim.Adam(all_parameters, lr=meta_lr)
    loss = criterion

    # Checkpoint tag; loop-invariant, so built once (the original rebuilt
    # it on every iteration).
    s = dataset + '_test_subj_' + str(test_subj) + '_shots_' + str(shots) + '_meta_lr_' + str(
        meta_lr) + '_fast_lr_' + \
        str(fast_lr) + '_meta_batch_size_' + str(meta_batch_size) + '_adaptation_steps_' + str(
        adaptation_steps) + str(model_name)

    print('start training...')
    for iteration in range(1, num_iterations + 1):
        opt.zero_grad()

        # Pair each source domain i with a random *distinct* pseudo-target.
        # NOTE(review): this requires at least two source domains, otherwise
        # the rejection loop below never terminates.
        index_arr = []
        for i in range(len(l2l_train_tasks_arr)):
            target_domain_id = random.choice(np.arange(len(l2l_train_tasks_arr)))
            while i == target_domain_id:
                target_domain_id = random.choice(np.arange(len(l2l_train_tasks_arr)))
            index_arr.append(target_domain_id)
        print(index_arr)

        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        cnt = 0

        for i in range(len(l2l_train_tasks_arr)):
            train_task = next(iter(l2l_train_tasks_arr[i]))
            val_task = next(iter(l2l_train_tasks_arr[index_arr[i]]))

            # Unadapted reference loss on the pseudo-target domain. It is
            # only compared against, never backpropped, so compute it under
            # no_grad (the original built a throwaway autograd graph here).
            val_data, val_labels = val_task
            val_data, val_labels = val_data.to(device), val_labels.to(device)
            former.eval()
            with torch.no_grad():
                valid_error = loss(latter(former(val_data)), val_labels)

            latter_cloned = latter.clone()
            former.train()
            train_error, train_accuracy = fast_adapt(train_task,
                                                     val_task,
                                                     latter_cloned,
                                                     former,
                                                     loss,
                                                     adaptation_steps,
                                                     shots,
                                                     ways,
                                                     device)

            # Backprop only if no negative transfer: target loss decreased.
            if train_error < valid_error:
                train_error.backward()
                cnt += 1
                meta_train_error += train_error.item()
                meta_train_accuracy += train_accuracy.item()

        print('Iteration', iteration)

        if cnt == 0:
            print('no match this epoch')
            continue

        print('Meta Train Error', meta_train_error / cnt)
        print('Meta Train Accuracy', meta_train_accuracy / cnt)

        if iteration % 50 == 0:
            print('saving model...')
            torch.save(former,
                       './runs/' + str(dataset) + '/mdmaml_model1_' + s + '_num_iterations_' + str(
                           iteration) + 'seed' + str(se) + '.pt')
            torch.save(latter,
                       './runs/' + str(dataset) + '/mdmaml_model2_' + s + '_num_iterations_' + str(
                           iteration) + 'seed' + str(se) + '.pt')

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            if p.grad is None:
                continue
            p.grad.data.mul_(1.0 / cnt)
        opt.step()


def accuracy(predictions, targets):
    """Fraction of correct argmax predictions, as a 0-dim float tensor."""
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(adaptation_batch, evaluation_batch, latter, features1, loss, adaptation_steps, shots, ways, device):
    """Adapt the cloned `latter` on the source batch, evaluate on the
    pseudo-target batch. Features from `features1` are fixed during
    adaptation (only the back end adapts)."""
    adaptation_data, adaptation_labels = adaptation_batch
    evaluation_data, evaluation_labels = evaluation_batch
    adaptation_data, adaptation_labels = adaptation_data.to(device), adaptation_labels.to(device)
    evaluation_data, evaluation_labels = evaluation_data.to(device), evaluation_labels.to(device)

    adaptation_data = features1(adaptation_data)

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(latter(adaptation_data), adaptation_labels)
        latter.adapt(train_error)

    # Evaluate the adapted model
    predictions = latter(features1(evaluation_data))
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy


if __name__ == '__main__':

    meta_lr = 0.001
    fast_lr = 0.001
    for model_name in ['EEGNet', 'ShallowConvNet']:
        for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']:
            if dataset == 'MI1':
                subj_num = 9
            elif dataset == 'MI2':
                subj_num = 14
            elif dataset == 'ERP1':
                subj_num = 10
            elif dataset == 'ERP2':
                subj_num = 16

            ways = 4 if dataset == 'MI1' else 2

            for test_subj in range(0, subj_num):
                for seed in range(0, 10):
                    print('MDMAML', dataset, model_name)
                    print('subj', test_subj, 'seed', seed)
                    main(test_subj=test_subj,
                         ways=ways,
                         shots=1,
                         meta_lr=meta_lr,
                         fast_lr=fast_lr,
                         meta_batch_size=1,
                         adaptation_steps=1,
                         num_iterations=500,
                         cuda=True,
                         seed=42,
                         model_name=model_name,
                         dataset=dataset,
                         se=seed,
                         )

# ====================================================================== #
# utils/func_utils.py (module head)
# -*- coding: utf-8 -*-
# @Time    : 2022/3/17 14:24
# @Author  : wenzhang
# @File    : func_utils.py
import os.path as osp
import os
import random
import torch as tr
import torch.nn as nn
from scipy.spatial.distance import cdist
from utils.utils import lr_scheduler
def obtain_label_decision(loader, netF, netC, device=None):
    """Cluster-refined class centers for decision-based pseudo-labeling.

    Returns ``(initc, all_fea)``: per-class centroids after one SSL
    refinement round, and the L2-normalized, bias-augmented feature matrix.

    ``device`` is a backward-compatible addition: it defaults to CUDA when
    available (the original unconditionally called ``.cuda()``), so CPU-only
    hosts now work too.
    """
    if device is None:
        device = 'cuda' if tr.cuda.is_available() else 'cpu'
    start_test = True
    with tr.no_grad():
        for data in loader:  # Py3 fix: was iter_test.next()
            inputs, labels = data[0].to(device), data[1]
            feas = netF(inputs.float())
            outputs = netC(feas)
            if start_test:
                all_fea = feas.float().cpu()
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_fea = tr.cat((all_fea, feas.float().cpu()), 0)
                all_output = tr.cat((all_output, outputs.float().cpu()), 0)
                all_label = tr.cat((all_label, labels.float()), 0)

    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = tr.max(all_output, 1)

    # Append a constant bias feature and row-normalize (SHOT-style).
    all_fea = tr.cat((all_fea, tr.ones(all_fea.size(0), 1)), 1)
    all_fea = (all_fea.t() / tr.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    dd = cdist(all_fea, initc, 'cosine')
    pred_label = dd.argmin(axis=1)

    for round in range(1):  # SSL refinement round
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    return initc, all_fea


def obtain_label_shot(loader, netF, netC, device=None):
    """SHOT-style self-supervised pseudo labels.

    Returns ``(pred_label, dd)``: refined integer pseudo labels and the
    final cosine-distance matrix to the centers of classes that received at
    least one prediction.

    ``device`` is a backward-compatible addition: defaults to CUDA when
    available (original behavior), otherwise CPU.
    """
    if device is None:
        device = 'cuda' if tr.cuda.is_available() else 'cpu'
    start_test = True
    with tr.no_grad():
        for data in loader:  # Py3 fix: was iter_test.next()
            inputs = data[0].to(device)
            labels = data[1]
            feas = netF(inputs)
            outputs = netC(feas)
            if start_test:
                all_fea = feas.float().cpu()
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_fea = tr.cat((all_fea, feas.float().cpu()), 0)
                all_output = tr.cat((all_output, outputs.float().cpu()), 0)
                all_label = tr.cat((all_label, labels.float()), 0)

    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = tr.max(all_output, 1)

    accuracy = tr.sum(tr.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])

    all_fea = tr.cat((all_fea, tr.ones(all_fea.size(0), 1)), 1)
    all_fea = (all_fea.t() / tr.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    # Only cluster against classes that actually received a prediction.
    cls_count = np.eye(K)[predict].sum(axis=0)
    labelset = np.where(cls_count > 0)[0]

    dd = cdist(all_fea, initc[labelset], 'cosine')
    pred_label = labelset[dd.argmin(axis=1)]

    for round in range(1):  # one SSL refinement round
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc[labelset], 'cosine')
        pred_label = labelset[dd.argmin(axis=1)]

    acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
    log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
    # print(log_str + '\n')

    return pred_label.astype('int'), dd
tr.from_numpy(temp1).cuda() 144 | temp2 = tr.from_numpy(temp2).cuda() 145 | initc.append(temp1) 146 | all_feas.append(temp2) 147 | netF_list[i].train() 148 | 149 | iter_num += 1 150 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 151 | 152 | ################################################################################### 153 | # output, domain weight, weighted output 154 | if args.use_weight: 155 | weights_all = tr.ones(inputs_target.shape[0], len(args.src)) 156 | tmp_output = tr.zeros(len(args.src), inputs_target.shape[0], args.class_num) 157 | for i in range(len(args.src)): 158 | tmp_output[i] = netC_list[i](netF_list[i](inputs_target)) 159 | weights_all[:, i] = netG_list[i](tmp_output[i]).squeeze() 160 | z = tr.sum(weights_all, dim=1) + 1e-16 161 | weights_all = tr.transpose(tr.transpose(weights_all, 0, 1) / z, 0, 1) 162 | weights_domain = tr.sum(weights_all, dim=0) / tr.sum(weights_all) 163 | domain_weight = weights_domain.reshape([1, num_src, 1]).cuda() 164 | else: 165 | domain_weight = tr.Tensor([1 / num_src] * num_src).reshape([1, num_src, 1]).cuda() 166 | weights_domain = np.round(tr.squeeze(domain_weight, 0).t().flatten().cpu().detach(), 3) 167 | # print(type(domain_weight), type(weights_domain)) # [1, 3, 1], [3] 168 | 169 | outputs_all = tr.cat([netC_list[i](netF_list[i](inputs_target)).unsqueeze(1) for i in range(num_src)], 1).cuda() 170 | preds = tr.softmax(outputs_all, dim=2) 171 | outputs_all_w = (preds * domain_weight).sum(dim=1).cuda() 172 | # print(outputs_all.shape, preds.shape, domain_weight.shape, outputs_all_w.shape) 173 | # [4, 8, 4], [4, 8, 4], [1, 8, 1], [4, 4] 174 | ################################################################################### 175 | 176 | # self pseudo label loss 177 | if args.cls_par > 0: 178 | initc_ = tr.zeros(initc[0].size()).cuda() 179 | temp = all_feas[0] 180 | all_feas_ = tr.zeros(temp[tar_idx, :].size()).cuda() 181 | for i in range(num_src): 182 | initc_ = initc_ + weights_domain[i] * 
initc[i].float() 183 | src_fea = all_feas[i] 184 | all_feas_ = all_feas_ + weights_domain[i] * src_fea[tar_idx, :] 185 | dd = tr.cdist(all_feas_.float(), initc_.float(), p=2) 186 | pred_label = dd.argmin(dim=1) 187 | pred = pred_label.int().long() 188 | clf_loss = nn.CrossEntropyLoss()(outputs_all_w, pred) 189 | else: 190 | clf_loss = tr.tensor(0.0).cuda() 191 | 192 | # raw decision 193 | im_loss = info_loss(outputs_all_w, args.epsilon) 194 | loss_all = args.cls_par * clf_loss + args.ent_par * im_loss 195 | 196 | optimizer.zero_grad() 197 | loss_all.backward() 198 | optimizer.step() 199 | 200 | 201 | def update_shot_ens(dset_loaders, netF, netC, optimizer, info_loss, args): 202 | # 分别对每个源模型和Xt训练,获得迁移之后的模型,然后集成 203 | max_iter = len(dset_loaders["target"]) 204 | 205 | iter_num = 0 206 | while iter_num < max_iter: 207 | iter_target = iter(dset_loaders["target"]) 208 | inputs_target, _, tar_idx = iter_target.next() 209 | if inputs_target.size(0) == 1: 210 | continue 211 | inputs_target = inputs_target.cuda() 212 | 213 | # 每10个epoch才进行一次pseudo labels增强 214 | interval_iter = 10 215 | if iter_num % interval_iter == 0 and args.cls_par > 0: 216 | netF.eval() 217 | mem_label, dd = obtain_label_shot(dset_loaders['Target'], netF, netC) 218 | mem_label = tr.from_numpy(mem_label).cuda() 219 | netF.train() 220 | 221 | iter_num += 1 222 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 223 | 224 | feas_target = netF(inputs_target.cuda()) 225 | outputs_target = netC(feas_target) 226 | 227 | # 构建 shot loss 228 | if args.cls_par > 0: # 控制需不需要加自监督loss,默认0.3 229 | pred = mem_label[tar_idx].long() 230 | clf_loss = nn.CrossEntropyLoss()(outputs_target, pred) 231 | else: 232 | clf_loss = tr.tensor(0.0).cuda() 233 | 234 | # IM loss on the weighted output 235 | im_loss = info_loss(outputs_target, args.epsilon) 236 | loss_all = args.cls_par * clf_loss + args.ent_par * im_loss 237 | 238 | optimizer.zero_grad() 239 | loss_all.backward() 240 | optimizer.step() 241 | 242 | 243 | 
def knowledge_vote(preds_softmax, confidence_gate, num_classes): 244 | max_p, max_p_class = preds_softmax.max(2) 245 | max_conf, _ = max_p.max(1) 246 | max_p_mask = (max_p > confidence_gate).float().cuda() 247 | preds_vote = tr.zeros(preds_softmax.size(0), preds_softmax.size(2)).cuda() 248 | for batch_idx, (p, p_class, p_mask) in enumerate(zip(max_p, max_p_class, max_p_mask)): 249 | if tr.sum(p_mask) > 0: 250 | p = p * p_mask 251 | for source_idx, source_class in enumerate(p_class): 252 | preds_vote[batch_idx, source_class] += p[source_idx] 253 | _, preds_vote = preds_vote.max(1) 254 | preds_vote = tr.zeros(preds_vote.size(0), num_classes).cuda().scatter_(1, preds_vote.view(-1, 1), 1) 255 | 256 | return preds_vote 257 | 258 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import TensorDataset, DataLoader 4 | 5 | from sklearn.metrics import accuracy_score, roc_auc_score 6 | 7 | from models.EEGNet import EEGNet 8 | from models.ShallowConvNet import ShallowConvNet, ShallowConvNetReduced 9 | from EEG_cross_subject_loader import EEG_loader 10 | 11 | import random 12 | import sys 13 | import time 14 | 15 | 16 | def main( 17 | test_subj=None, 18 | learning_rate=None, 19 | adaptation_iterations=None, 20 | cuda=None, 21 | seed_num=None, 22 | shots=None, 23 | test_path=None, 24 | dataset=None, 25 | model_name=None, 26 | roc_auc=0, 27 | ): 28 | 29 | random.seed(1) 30 | np.random.seed(1) 31 | torch.manual_seed(1) 32 | device = torch.device('cpu') 33 | if cuda: 34 | torch.cuda.manual_seed(1) 35 | device = torch.device('cuda:5') 36 | #print('using cuda...') 37 | 38 | 39 | if model_name == 'ShallowConvNet': 40 | if dataset == 'MI1': 41 | model = ShallowConvNet(4, 22, 16640) 42 | if dataset == 'MI2': 43 | model = ShallowConvNet(2, 15, 26520) 44 | if dataset == 'ERP1': 45 | model = 
ShallowConvNetReduced(2, 16, 6760) 46 | if dataset == 'ERP2': 47 | model = ShallowConvNet(2, 56, 17160) 48 | elif model_name == 'EEGNet': 49 | if dataset == 'MI1': 50 | model = EEGNet(22, 256, 4) 51 | if dataset == 'MI2': 52 | model = EEGNet(15, 384, 2) 53 | if dataset == 'ERP1': 54 | model = EEGNet(16, 32, 2) 55 | if dataset == 'ERP2': 56 | model = EEGNet(56, 256, 2) 57 | 58 | if dataset == 'MI2': 59 | class_num = 4 60 | else: 61 | class_num = 2 62 | 63 | data = EEG_loader(test_subj=test_subj, dataset=dataset) 64 | 65 | if shots == 0: 66 | 67 | test_x, test_y = data.test_x, data.test_y 68 | tensor_test_x, tensor_test_y = torch.from_numpy(test_x).unsqueeze_(3).to( 69 | torch.float32), torch.squeeze( 70 | torch.from_numpy(test_y), 1).to(torch.long) 71 | 72 | test_dataset = TensorDataset(tensor_test_x, tensor_test_y) 73 | test_loader = DataLoader(test_dataset) 74 | 75 | #model.load_state_dict(torch.load(test_path)) 76 | model = torch.load(test_path) 77 | model.to(device) 78 | model.eval() 79 | 80 | y_true = [] 81 | y_pred = [] 82 | with torch.no_grad(): 83 | correct = 0 84 | total = 0 85 | for x, y in test_loader: 86 | x = x.to(device) 87 | y = y.to(device) 88 | outputs = model(x) 89 | _, predicted = torch.max(outputs.data, 1) 90 | total += y.size(0) 91 | 92 | y_true.append(y.item()) 93 | y_pred.append(predicted.item()) 94 | 95 | # print('Accuracy of the network on the test subject : {} %'.format(100 * correct / total)) 96 | if roc_auc == '1': 97 | print('using roc_auc') 98 | out = roc_auc_score(y_true, y_pred) 99 | else: 100 | out = accuracy_score(y_true, y_pred) 101 | 102 | return round(out,3), 0 103 | 104 | accuracy_arr = [] 105 | for seed in range(seed_num): 106 | target_x, target_y = data.test_x, data.test_y 107 | # print(target_x.shape, target_y.shape) 108 | # input('') 109 | 110 | np.random.seed(seed) 111 | idx = list(range(len(target_y))) 112 | np.random.shuffle(idx) 113 | target_x = target_x[idx] 114 | target_y = target_y[idx] 115 | 116 | calib_ind = 
np.ones(class_num * 2) * shots 117 | 118 | train_x = [] 119 | train_y = [] 120 | train_index = [] 121 | for j in range(class_num): 122 | for i in range(len(target_y)): 123 | # print(target_y[i, 0]) 124 | if target_y[i, 0] == j: 125 | train_x.append(target_x[i]) 126 | train_y.append(target_y[i]) 127 | train_index.append([i]) 128 | calib_ind[j] -= 1 129 | # print(calib_ind[j]) 130 | if calib_ind[j] == 0.0: 131 | break 132 | train_x = np.array(train_x) 133 | train_y = np.array(train_y) 134 | #print('calibration labels: ', train_y) 135 | # print(train_x.shape, train_y.shape) 136 | # print(train_index) 137 | 138 | test_x = np.delete(target_x, train_index, axis=0) 139 | test_y = np.delete(target_y, train_index, axis=0) 140 | # print(test_x.shape, test_y.shape) 141 | 142 | tensor_train_x, tensor_train_y = torch.from_numpy(train_x).unsqueeze_(3).to( 143 | torch.float32), torch.squeeze( 144 | torch.from_numpy(train_y), 1).to(torch.long) 145 | 146 | train_dataset = TensorDataset(tensor_train_x, tensor_train_y) 147 | train_loader = DataLoader(train_dataset) 148 | 149 | tensor_test_x, tensor_test_y = torch.from_numpy(test_x).unsqueeze_(3).to( 150 | torch.float32), torch.squeeze( 151 | torch.from_numpy(test_y), 1).to(torch.long) 152 | 153 | test_dataset = TensorDataset(tensor_test_x, tensor_test_y) 154 | test_loader = DataLoader(test_dataset) 155 | 156 | # print(train_x.shape, test_x.shape) 157 | 158 | model.to(device) 159 | opt = torch.optim.Adam(model.parameters(), lr=learning_rate) 160 | criterion = torch.nn.CrossEntropyLoss() 161 | 162 | model.load_state_dict(torch.load(test_path)) 163 | 164 | # Train the model 165 | for epoch in range(adaptation_iterations): 166 | # print('epoch:', epoch + 1) 167 | #total_loss = 0 168 | #cnt = 0 169 | for i, (x, y) in enumerate(train_loader): 170 | # Forward pass 171 | x = x.to(device) 172 | y = y.to(device) 173 | 174 | outputs = model(x) 175 | loss = criterion(outputs, y) 176 | #total_loss += loss 177 | #cnt += 1 178 | 179 | # Backward 
and optimize 180 | opt.zero_grad() 181 | loss.backward() 182 | opt.step() 183 | #out_loss = total_loss / cnt 184 | 185 | # print('Epoch [{}/{}], , Loss: {:.4f}' 186 | # .format(epoch + 1, epoch, out_loss)) 187 | 188 | # Test the model 189 | model.eval() 190 | 191 | y_true = [] 192 | y_pred = [] 193 | with torch.no_grad(): 194 | correct = 0 195 | total = 0 196 | for x, y in test_loader: 197 | x = x.to(device) 198 | y = y.to(device) 199 | outputs = model(x) 200 | _, predicted = torch.max(outputs.data, 1) 201 | total += y.size(0) 202 | y_true.append(y.item()) 203 | y_pred.append(predicted.item()) 204 | 205 | # print('Accuracy of the network on the test subject : {} %'.format(100 * correct / total)) 206 | if roc_auc == '1': 207 | print('using roc_auc') 208 | out = roc_auc_score(y_true, y_pred) 209 | else: 210 | out = accuracy_score(y_true, y_pred) 211 | accuracy_arr.append(round(out,3)) 212 | #del data, target_x, target_y, tensor_train_x, tensor_train_y, tensor_test_x, tensor_test_y, train_dataset, test_dataset 213 | #print(accuracy_arr, round(np.average(accuracy_arr), 3), round(np.std(accuracy_arr), 3)) 214 | #input('') 215 | return round(np.average(accuracy_arr), 3), round(np.std(accuracy_arr), 3) 216 | 217 | 218 | if __name__ == '__main__': 219 | model_name = 'EEGNet' 220 | dataset = 'MI1' 221 | 222 | seed_num = 10 # test times of random calibration data 223 | learning_rate = 0.001 # normal learning rate 224 | meta_learning_rate = 0.001 # meta learning rate 225 | 226 | #shots = 1 227 | #adaptation_iterations = 2 228 | #test_load_epoch = 10 229 | 230 | mode = str(sys.argv[1]) # mode name 231 | shots = int(sys.argv[2]) # test subject calibration data of shots number (shots * classes) 232 | adaptation_iterations = int(sys.argv[3]) # test subject calibration data adaptation iterations/steps 233 | test_load_epoch = str(sys.argv[4]) # file of loaded saved parameters epoch number 234 | train_shots = str(sys.argv[5]) # file of loaded saved parameters shots number 235 | 
train_adaption_steps = str(sys.argv[6]) # file of loaded saved parameters adaptation steps number 236 | roc_auc = str(sys.argv[7]) # use roc_auc or not, 1 for yes, 0 for no 237 | 238 | avg_arr = [] 239 | std_arr = [] 240 | all_arr = [] 241 | for i in range(0, 14): 242 | out_acc_arr = [] 243 | out_std_arr = [] 244 | for s in range(0, 1): 245 | print('subj', i, 'seed', s) 246 | if mode == 'base' or mode == 'finetune': 247 | # shots = 0 248 | path = './runs/' + str(dataset) + '/baseline_' + str(model_name) + str(dataset) + '_seed' + str( 249 | s) + '_test_subj_' + str( 250 | i) + '_epoch' + str(test_load_epoch) + '.pt' 251 | elif mode == 'maml': 252 | path = './runs/' + str(dataset) + '/' + 'maml_' + str(dataset) + '_test_subj_' + str( 253 | i) + '_shots_' + str( 254 | train_shots) + '_meta_lr_' + str(meta_learning_rate) + '_fast_lr_' + str( 255 | learning_rate) + '_meta_batch_size_1_adaptation_steps_' + str(train_adaption_steps) + str( 256 | model_name) + '_num_iterations_' + str(test_load_epoch) + 'seed' + str(s) + '.pth' 257 | elif mode == 'mdmaml': 258 | path = './runs/' + str(dataset) + '/' + 'mdmaml_' + str(dataset) + '_test_subj_' + str( 259 | i) + '_shots_' + str( 260 | train_shots) + '_meta_lr_' + str(meta_learning_rate) + '_fast_lr_' + str( 261 | learning_rate) + '_meta_batch_size_1_adaptation_steps_' + str(train_adaption_steps) + str( 262 | model_name) + '_num_iterations_'+ str(test_load_epoch) + 'seed' + str(s) + '.pth' 263 | #'cdmaml_MI1_test_subj_1_shots_25_meta_lr_0.001_fast_lr_0.001_meta_batch_size_1_adaptation_steps_1EEGNetwithload_num_iterations_5seed0' 264 | elif mode == 'cdmaml+': 265 | path = './runs/' + str(dataset) + 'cdmaml+/' + 'cdmaml+_' + str(dataset) + '_test_subj_' + str( 266 | i) + '_shots_' + str( 267 | train_shots) + '_meta_lr_' + str(meta_learning_rate) + '_fast_lr_' + str( 268 | learning_rate) + '_meta_batch_size_1_adaptation_steps_1' + str( 269 | model_name) + '_num_iterations_' + str(test_load_epoch) + 'seed' + str(s) + '.pth' 270 | 
elif mode == 'cdmaml-': 271 | path = './runs/' + str(dataset) + 'cdmaml-/' + 'cdmaml-_' + str(dataset) + '_test_subj_' + str( 272 | i) + '_shots_' + str( 273 | train_shots) + '_meta_lr_' + str(meta_learning_rate) + '_fast_lr_' + str( 274 | learning_rate) + '_meta_batch_size_1_adaptation_steps_1' + str( 275 | model_name) + 'withload_num_iterations_' + str(test_load_epoch) + 'seed' + str(s) + '.pth' 276 | acc, std = main( 277 | test_subj=i, 278 | learning_rate=learning_rate, 279 | adaptation_iterations=adaptation_iterations, 280 | cuda=False, 281 | seed_num=seed_num, 282 | shots=shots, 283 | test_path=path, 284 | dataset=dataset, 285 | model_name=model_name, 286 | roc_auc=roc_auc, 287 | ) 288 | print('score, std:', acc, std) 289 | out_acc_arr.append(round(acc, 3)) 290 | out_std_arr.append(round(std, 3)) 291 | print(out_acc_arr) 292 | all_arr.append(out_acc_arr) 293 | avg_arr.append(round(np.average(out_acc_arr), 3)) 294 | std_arr.append(round(np.std(out_acc_arr), 5)) 295 | 296 | total_avg = round(np.average(avg_arr), 5) 297 | total_std = round(np.std(np.average(all_arr, axis=0)), 5) 298 | print('#' * 32) 299 | print(dataset, model_name, mode, '\nshots:', shots, '; iter:', adaptation_iterations, '; loaded_epoch:', 300 | test_load_epoch) 301 | print('avg_arr:', avg_arr) 302 | print('std_arr:', std_arr) 303 | print('total_avg:', total_avg) 304 | print('total_std:', total_std) 305 | -------------------------------------------------------------------------------- /ProtoNets.py: -------------------------------------------------------------------------------- 1 | import learn2learn as l2l 2 | import numpy as np 3 | import torch 4 | 5 | from learn2learn.data.transforms import NWays, KShots, LoadData 6 | from torch.utils.data import TensorDataset, DataLoader 7 | from sklearn.metrics import accuracy_score, roc_auc_score 8 | 9 | from models.EEGNet import EEGNet_features 10 | from models.ShallowConvNet import ShallowConvNetFeatures, ShallowConvNetFeaturesReduced 11 | from 
EEG_cross_subject_loader import EEG_loader 12 | 13 | import random 14 | 15 | 16 | def main( 17 | test_subj=None, 18 | ways=None, 19 | shots=None, 20 | lr=None, 21 | num_iterations=None, 22 | cuda=None, 23 | seed=None, 24 | model_name=None, 25 | dataset=None, 26 | se=None, 27 | test=True, 28 | path=None, 29 | ): 30 | 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | device = torch.device('cpu') 35 | if cuda: 36 | torch.cuda.manual_seed(seed) 37 | device = torch.device('cuda:4') 38 | print('using cuda...') 39 | 40 | data = EEG_loader(test_subj=test_subj, dataset=dataset) 41 | if not test: 42 | train_x_arr = data.train_x 43 | train_y_arr = data.train_y 44 | train_x_arr_tmp = [] 45 | train_y_arr_tmp = [] 46 | for train_x, train_y in zip(train_x_arr, train_y_arr): 47 | train_x_arr_tmp.append(train_x) 48 | train_y_arr_tmp.append(train_y) 49 | 50 | l2l_train_tasks_arr = [] 51 | meta_batch_size = 8 52 | for i in range(len(train_x_arr_tmp)): 53 | tensor_train_x, tensor_train_y = torch.from_numpy(train_x_arr_tmp[i]).unsqueeze_(3).to( 54 | torch.float32), torch.squeeze(torch.from_numpy(train_y_arr_tmp[i]), 1).to(torch.long) 55 | train_torch_dataset = TensorDataset(tensor_train_x, tensor_train_y) 56 | train_dataset = l2l.data.MetaDataset(train_torch_dataset) 57 | train_tasks = l2l.data.TaskDataset(train_dataset, 58 | task_transforms=[ 59 | NWays(train_dataset, n=ways), 60 | KShots(train_dataset, k=2 * shots), 61 | LoadData(train_dataset), 62 | ], 63 | num_tasks=meta_batch_size) 64 | l2l_train_tasks_arr.append(train_tasks) 65 | del train_x_arr, train_y_arr, train_x_arr_tmp, train_y_arr_tmp, data, train_dataset, train_torch_dataset, tensor_train_x, tensor_train_y 66 | else: 67 | test_x, test_y = data.test_x, data.test_y 68 | 69 | tensor_test_x, tensor_test_y = torch.from_numpy(test_x).unsqueeze_(3).to( 70 | torch.float32), torch.squeeze( 71 | torch.from_numpy(test_y), 1).to(torch.long) 72 | 73 | test_dataset = TensorDataset(tensor_test_x, 
tensor_test_y) 74 | test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, drop_last=True) 75 | 76 | if model_name == 'ShallowConvNet': 77 | if dataset == 'MI1': 78 | model = ShallowConvNetFeatures(4, 22, 16640) 79 | if dataset == 'MI2': 80 | model = ShallowConvNetFeatures(2, 15, 26520) 81 | if dataset == 'ERP1': 82 | model = ShallowConvNetFeaturesReduced(2, 16, 6760) 83 | if dataset == 'ERP2': 84 | model = ShallowConvNetFeatures(2, 56, 17160) 85 | elif model_name == 'EEGNet': 86 | if dataset == 'MI1': 87 | model = EEGNet_features(22, 256, 4) 88 | if dataset == 'MI2': 89 | model = EEGNet_features(15, 384, 2) 90 | if dataset == 'ERP1': 91 | model = EEGNet_features(16, 32, 2) 92 | if dataset == 'ERP2': 93 | model = EEGNet_features(56, 256, 2) 94 | 95 | # subject number 96 | k = -1 97 | if dataset == 'MI2': 98 | k = 9 99 | if dataset == 'MI1': 100 | k = 14 101 | if dataset == 'ERP2': 102 | k = 16 103 | if dataset == 'ERP1': 104 | k = 10 105 | 106 | model.to(device) 107 | 108 | opt = torch.optim.Adam(model.parameters(), lr=lr) 109 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 110 | opt, step_size=20, gamma=0.5) 111 | 112 | if not test: 113 | print('start training...') 114 | for iteration in range(1, num_iterations + 1): 115 | 116 | meta_train_error = 0.0 117 | meta_train_accuracy = 0.0 118 | cnt = 0 119 | 120 | for tasks_subj_ind in range(len(l2l_train_tasks_arr)): 121 | for batch in l2l_train_tasks_arr[tasks_subj_ind]: 122 | loss, acc = fast_adapt(model, 123 | batch, 124 | ways, 125 | shots, 126 | metric=pairwise_distances_logits, 127 | device=device) 128 | 129 | cnt += 1 130 | meta_train_error += loss.item() 131 | meta_train_accuracy += acc.item() 132 | 133 | opt.zero_grad() 134 | loss.backward() 135 | opt.step() 136 | 137 | lr_scheduler.step() 138 | 139 | print('Iteration', iteration) 140 | print('Meta Train Error', meta_train_error / cnt) 141 | print('Meta Train Accuracy', meta_train_accuracy / cnt) 142 | 143 | s = dataset + '_test_subj_' + 
str(test_subj) + '_' + str(model_name) 144 | 145 | if iteration % 50 == 0: 146 | print('saving model...') 147 | 148 | torch.save(model, 149 | './runs/' + str(dataset) + '/protonets_' + s + '_num_iterations_' + str(iteration) + 150 | '_seed' + str(se) + '.pt') 151 | 152 | else: 153 | model = torch.load(path) 154 | model.eval() 155 | 156 | metric_arr = [] 157 | seed_num = 10 158 | for seed in range(seed_num): 159 | 160 | pred_arr = [] 161 | targ_arr = [] 162 | for i, batch in enumerate(test_loader, 1): 163 | predictions, targets = fast_adapt_test(model, 164 | batch, 165 | ways, 166 | shots, 167 | metric=pairwise_distances_logits, 168 | device=device) 169 | 170 | predictions, targets = predictions.tolist(), targets.tolist() 171 | 172 | pred_arr.extend(predictions) 173 | targ_arr.extend(targets) 174 | 175 | if dataset == 'ERP1' or dataset == 'ERP2': 176 | score = roc_auc_score(targets, predictions) 177 | else: 178 | score = accuracy_score(targets, predictions) 179 | metric_arr.append(round(score,5)) 180 | print(metric_arr) 181 | print(round(np.average(metric_arr),5)) 182 | print(round(np.std(metric_arr), 5)) 183 | 184 | 185 | def accuracy(predictions, targets): 186 | predictions = predictions.argmax(dim=1).view(targets.shape) 187 | return (predictions == targets).sum().float() / targets.size(0) 188 | 189 | 190 | def auc_counter(predictions, targets): 191 | predictions = predictions.argmax(dim=1).view(targets.shape) 192 | return predictions.numpy(), targets.numpy() 193 | 194 | 195 | def pairwise_distances_logits(a, b): 196 | n = a.shape[0] 197 | m = b.shape[0] 198 | logits = -((a.unsqueeze(1).expand(n, m, -1) - 199 | b.unsqueeze(0).expand(n, m, -1)) ** 2).sum(dim=2) 200 | return logits 201 | 202 | 203 | def fast_adapt(model, batch, ways, shot, metric=None, device=None): 204 | if metric is None: 205 | metric = pairwise_distances_logits 206 | if device is None: 207 | device = model.device() 208 | data, labels = batch 209 | data = data.to(device) 210 | labels = 
labels.to(device) 211 | 212 | # Sort data samples by labels 213 | sort = torch.sort(labels) 214 | data = data.squeeze(0)[sort.indices].squeeze(0) 215 | labels = labels.squeeze(0)[sort.indices].squeeze(0) 216 | 217 | # Compute support and query embeddings 218 | embeddings = model(data) 219 | support_indices = np.zeros(data.size(0), dtype=bool) 220 | selection = np.arange(ways) * (shot * 2) 221 | for offset in range(shot): 222 | if (selection + offset)[-1] >= len(support_indices): 223 | return 0,0 224 | support_indices[selection + offset] = True 225 | query_indices = torch.from_numpy(~support_indices) 226 | support_indices = torch.from_numpy(support_indices) 227 | support = embeddings[support_indices] 228 | support = support.reshape(ways, shot, -1).mean(dim=1) 229 | query = embeddings[query_indices] 230 | labels = labels[query_indices].long() 231 | 232 | logits = pairwise_distances_logits(query, support) 233 | loss = torch.nn.functional.cross_entropy(logits, labels) 234 | acc = accuracy(logits, labels) 235 | return loss, acc 236 | 237 | 238 | def fast_adapt_test(model, batch, ways, shot, metric=None, device=None): 239 | if metric is None: 240 | metric = pairwise_distances_logits 241 | if device is None: 242 | device = model.device() 243 | data, labels = batch 244 | data = data.to(device) 245 | labels = labels.to(device) 246 | 247 | # Sort data samples by labels 248 | # TODO: Can this be replaced by ConsecutiveLabels ? 
249 | sort = torch.sort(labels) 250 | data = data.squeeze(0)[sort.indices].squeeze(0) 251 | labels = labels.squeeze(0)[sort.indices].squeeze(0) 252 | 253 | # Compute support and query embeddings 254 | embeddings = model(data) 255 | support_indices = np.zeros(data.size(0), dtype=bool) 256 | selection = np.arange(ways) * (shot * 2) 257 | for offset in range(shot): 258 | if (selection + offset)[-1] >= len(support_indices): 259 | return 0,0 260 | support_indices[selection + offset] = True 261 | query_indices = torch.from_numpy(~support_indices) 262 | support = embeddings[support_indices] 263 | support = support.reshape(ways, shot, -1).mean(dim=1) 264 | query = embeddings[query_indices] 265 | labels = labels[query_indices].long() 266 | 267 | logits = metric(query, support) 268 | #loss = F.cross_entropy(logits, labels) 269 | predictions, targets = auc_counter(logits, labels) 270 | 271 | return predictions, targets 272 | 273 | 274 | if __name__ == '__main__': 275 | 276 | lr = 0.001 277 | num_iterations = 100 278 | test = False 279 | for model_name in ['EEGNet', 'ShallowConvNet']: 280 | for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']: 281 | if dataset == 'MI1': 282 | subj_num = 9 283 | shots = 576 // (2 * 4 * 8) + 1 284 | if test: 285 | shots = 2 # 4 6 286 | elif dataset == 'MI2': 287 | subj_num = 14 288 | shots = 100 // (2 * 2 * 8) + 1 289 | if test: 290 | shots = 2 # 4 6 291 | elif dataset == 'ERP1': 292 | subj_num = 10 293 | shots = 96 // (2 * 8) + 1 294 | if test: 295 | shots = 2 # 4 6 296 | elif dataset == 'ERP2': 297 | subj_num = 16 298 | shots = (24 + 26) // (2 * 8) + 1 # pad to 50 299 | if test: 300 | shots = 2 # 4 6 301 | 302 | if dataset == 'MI1': 303 | ways = 4 304 | else: 305 | ways = 2 306 | 307 | test_load_epoch = 100 308 | for subj in range(0, subj_num): 309 | for se in range(0, 10): 310 | path = './runs/' + str(dataset) + '/protonet_' + str(dataset) + '_test_subj_' + str(subj) + '_' + \ 311 | str(model_name) + '_num_iterations_' + str(test_load_epoch) + 
'_seed' + str(se) + '.pth' 312 | print('ProtoNet', dataset, model_name) 313 | print('subj', subj, 'seed', se) 314 | main(test_subj=subj, 315 | ways=ways, 316 | shots=shots, 317 | lr=lr, 318 | num_iterations=num_iterations, 319 | cuda=True, 320 | seed=42, 321 | model_name=model_name, 322 | dataset=dataset, 323 | se=se, 324 | test=test, 325 | path=path, 326 | ) 327 | -------------------------------------------------------------------------------- /utils/network.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch as tr 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.utils.weight_norm as weightNorm 7 | 8 | from models.EEGNet import EEGNet_features, EEGNet_classifier 9 | 10 | 11 | # dynamic change the weight of the domain-discriminator 12 | def calc_coeff(iter_num, alpha=10.0, max_iter=10000.0): 13 | return np.float(2.0 / (1.0 + np.exp(-alpha * iter_num / max_iter)) - 1) 14 | 15 | 16 | def init_weights(m): 17 | classname = m.__class__.__name__ 18 | if classname.find('BatchNorm') != -1: 19 | nn.init.normal_(m.weight, 1.0, 0.02) 20 | nn.init.zeros_(m.bias) 21 | elif classname.find('Linear') != -1: 22 | nn.init.xavier_normal_(m.weight) 23 | nn.init.zeros_(m.bias) 24 | 25 | 26 | class Net_ln2(nn.Module): 27 | def __init__(self, n_feature, n_hidden, bottleneck_dim): 28 | super(Net_ln2, self).__init__() 29 | self.act = nn.ReLU() 30 | self.fc1 = nn.Linear(n_feature, n_hidden) 31 | self.ln1 = nn.LayerNorm(n_hidden) 32 | self.fc2 = nn.Linear(n_hidden, bottleneck_dim) 33 | self.fc2.apply(init_weights) 34 | self.ln2 = nn.LayerNorm(bottleneck_dim) 35 | 36 | def forward(self, x): 37 | x = self.act(self.ln1(self.fc1(x))) 38 | x = self.act(self.ln2(self.fc2(x))) 39 | x = x.view(x.size(0), -1) 40 | return x 41 | 42 | 43 | class Net_CFE(nn.Module): 44 | def __init__(self, input_dim=310, bottleneck_dim=64): 45 | if input_dim < 256: 46 | print('\nwarning', 
'input_dim < 256') 47 | super(Net_CFE, self).__init__() 48 | self.module = nn.Sequential( 49 | nn.Linear(input_dim, 256), 50 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 51 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 52 | nn.Linear(256, 128), 53 | # nn.BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 54 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 55 | nn.Linear(128, bottleneck_dim), # default 64 56 | # nn.BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 57 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 58 | ) 59 | 60 | def forward(self, x): 61 | x = self.module(x) 62 | return x 63 | 64 | 65 | class feat_bottleneck(nn.Module): 66 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 67 | super(feat_bottleneck, self).__init__() 68 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.dropout = nn.Dropout(p=0.5) 71 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 72 | self.bottleneck.apply(init_weights) 73 | self.type = type 74 | 75 | def forward(self, x): 76 | x = self.bottleneck(x) 77 | if self.type == "bn": 78 | x = self.bn(x) 79 | return x 80 | 81 | 82 | class feat_classifier(nn.Module): 83 | def __init__(self, class_num, bottleneck_dim, type="linear"): 84 | super(feat_classifier, self).__init__() 85 | self.type = type 86 | if type == 'wn': # 后边换成linear试试 87 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 88 | self.fc.apply(init_weights) 89 | else: 90 | self.fc = nn.Linear(bottleneck_dim, class_num) 91 | self.fc.apply(init_weights) 92 | 93 | def forward(self, x): 94 | x = self.fc(x) 95 | return x 96 | 97 | 98 | class feat_classifier_xy(nn.Module): 99 | def __init__(self, class_num, bottleneck_dim, type="linear"): 100 | super(feat_classifier_xy, self).__init__() 101 | self.type = type 102 | if type == 'wn': 103 | self.fc = 
weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 104 | self.fc.apply(init_weights) 105 | else: 106 | self.fc = nn.Linear(bottleneck_dim, class_num) 107 | self.fc.apply(init_weights) 108 | 109 | def forward(self, x): 110 | y = self.fc(x) 111 | return x, y 112 | 113 | 114 | def backbone_net(args, a, b, c, return_type='y'): 115 | n_hidden = None 116 | if args.backbone == 'Net_ln2': 117 | netF = Net_ln2(args.input_dim, n_hidden, args.bottleneck).cuda() 118 | 119 | if args.backbone == 'Net_CFE': 120 | netF = Net_CFE(args.input_dim, args.bottleneck).cuda() 121 | 122 | # added 123 | if args.backbone == 'EEGNet': 124 | netF = EEGNet_features(a, b, c) 125 | 126 | if return_type == 'y': 127 | netC = feat_classifier(class_num=args.class_num, bottleneck_dim=args.bottleneck, type=args.layer).cuda() 128 | if return_type == 'xy': 129 | netC = feat_classifier_xy(class_num=args.class_num, bottleneck_dim=args.bottleneck, type=args.layer).cuda() 130 | 131 | # added 132 | if return_type == 'z': 133 | netC = EEGNet_classifier(a, b, c) 134 | 135 | return netF, netC 136 | 137 | 138 | class scalar(nn.Module): 139 | def __init__(self, init_weights): 140 | super(scalar, self).__init__() 141 | self.w = nn.Parameter(tr.tensor(1.) 
* init_weights) 142 | 143 | def forward(self, x): 144 | x = self.w * tr.ones((x.shape[0]), 1).cuda() 145 | x = tr.sigmoid(x) 146 | return x 147 | 148 | 149 | def grl_hook(coeff): 150 | def fun1(grad): 151 | return -coeff * grad.clone() 152 | 153 | return fun1 154 | 155 | 156 | class Discriminator(nn.Module): 157 | def __init__(self, input_dim=2048, hidden_dim=2048): 158 | super(Discriminator, self).__init__() 159 | self.input_dim = input_dim 160 | self.hidden_dim = hidden_dim 161 | self.ln1 = nn.Linear(input_dim, hidden_dim) 162 | self.bn = nn.BatchNorm1d(hidden_dim) 163 | self.ln2 = nn.Linear(hidden_dim, 1) 164 | 165 | def forward(self, x): 166 | x = F.relu(self.ln1(x)) 167 | x = self.ln2(self.bn(x)) 168 | y = tr.sigmoid(x) 169 | return y 170 | 171 | 172 | class AdversarialNetwork(nn.Module): 173 | def __init__(self, in_feature, hidden_size): 174 | super(AdversarialNetwork, self).__init__() 175 | self.ad_layer1 = nn.Linear(in_feature, hidden_size) 176 | self.ad_layer2 = nn.Linear(hidden_size, hidden_size) 177 | self.ad_layer3 = nn.Linear(hidden_size, 1) 178 | self.relu1 = nn.ReLU() 179 | self.relu2 = nn.ReLU() 180 | self.dropout1 = nn.Dropout(0.5) 181 | self.dropout2 = nn.Dropout(0.5) 182 | self.sigmoid = nn.Sigmoid() 183 | self.apply(init_weights) 184 | self.iter_num = 0 185 | self.alpha = 10 186 | self.max_iter = 10000.0 187 | 188 | def forward(self, x): 189 | if self.training: 190 | self.iter_num += 1 191 | coeff = calc_coeff(self.iter_num, self.alpha, self.max_iter) 192 | x = x * 1.0 193 | x.register_hook(grl_hook(coeff)) 194 | x = self.ad_layer1(x) 195 | x = self.relu1(x) 196 | x = self.dropout1(x) 197 | x = self.ad_layer2(x) 198 | x = self.relu2(x) 199 | x = self.dropout2(x) 200 | y = self.ad_layer3(x) 201 | y = self.sigmoid(y) 202 | return y 203 | 204 | def output_num(self): 205 | return 1 206 | 207 | def get_parameters(self): 208 | return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}] 209 | 210 | 211 | # 
=============================================================MSMDA Function=========================================== 212 | class CFE(nn.Module): 213 | def __init__(self, input_dim=310): 214 | if input_dim < 256: 215 | print('\nerr', 'input_dim < 256') 216 | super(CFE, self).__init__() 217 | self.module = nn.Sequential( 218 | nn.Linear(input_dim, 256), 219 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 220 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 221 | nn.Linear(256, 128), 222 | # nn.BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 223 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 224 | nn.Linear(128, 64), 225 | # nn.BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 226 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 227 | ) 228 | 229 | def forward(self, x): 230 | x = self.module(x) 231 | return x 232 | 233 | 234 | class DSFE(nn.Module): 235 | def __init__(self): 236 | super(DSFE, self).__init__() 237 | self.module = nn.Sequential( 238 | nn.Linear(64, 32), 239 | # nn.ReLU(inplace=True), 240 | nn.BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 241 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 242 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 243 | ) 244 | 245 | def forward(self, x): 246 | x = self.module(x) 247 | return x 248 | 249 | 250 | def mmd_linear(f_of_X, f_of_Y): 251 | delta = f_of_X - f_of_Y 252 | loss = tr.mean(tr.mm(delta, tr.transpose(delta, 0, 1))) 253 | return loss 254 | 255 | 256 | class MSMDAERNet(nn.Module): 257 | def __init__(self, backbone_net, num_src=14, num_class=3): 258 | super(MSMDAERNet, self).__init__() 259 | self.sharedNet = backbone_net 260 | # for i in range(1, num_src): 261 | # exec('self.DSFE' + str(i) + '=DSFE()') 262 | # exec('self.cls_fc_DSC' + str(i) + '=nn.Linear(32,' + str(num_class) + ')') 263 | for i in range(num_src): 264 | exec('self.DSFE' + str(i) + 
'=DSFE()') 265 | exec('self.cls_fc_DSC' + str(i) + '=nn.Linear(32,' + str(num_class) + ')') 266 | 267 | def forward(self, data_src, num_src, data_tgt=0, label_src=0, mark=0): 268 | ''' 269 | description: take one source data and the target data in every forward operation. 270 | the mmd loss is calculated between the source data and the target data (both after the DSFE) 271 | the discrepency loss is calculated between all the classifiers' results (test on the target data) 272 | the cls loss is calculated between the ground truth label and the prediction of the mark-th classifier 273 | 之所以target data每一条线都要过一遍是因为要计算discrepency loss, mmd和cls都只要mark-th那条线就行 274 | param {type}: 275 | mark: int, the order of the current source 276 | data_src: take one source data each time 277 | number_of_source: int 278 | label_Src: corresponding label 279 | data_tgt: target data 280 | return {type} 281 | ''' 282 | mmd_loss = 0 283 | disc_loss = 0 284 | data_tgt_DSFE = [] 285 | if self.training == True: 286 | # common feature extractor 287 | data_src_CFE = self.sharedNet(data_src) 288 | data_tgt_CFE = self.sharedNet(data_tgt) 289 | 290 | # Each domian specific feature extractor 291 | # to extract the domain specific feature of target data 292 | for i in range(num_src): 293 | DSFE_name = 'self.DSFE' + str(i) 294 | data_tgt_DSFE_i = eval(DSFE_name)(data_tgt_CFE) 295 | data_tgt_DSFE.append(data_tgt_DSFE_i) 296 | 297 | # Use the specific feature extractor 298 | # to extract the source data, and calculate the mmd loss 299 | DSFE_name = 'self.DSFE' + str(mark) 300 | data_src_DSFE = eval(DSFE_name)(data_src_CFE) 301 | 302 | # mmd_loss += mmd(data_src_DSFE, data_tgt_DSFE[mark]) 303 | mmd_loss += mmd_linear(data_src_DSFE, data_tgt_DSFE[mark]) 304 | 305 | # discrepency loss 306 | for i in range(len(data_tgt_DSFE)): 307 | if i != mark: 308 | disc_loss += tr.mean(tr.abs( 309 | F.softmax(data_tgt_DSFE[mark], dim=1) - 310 | F.softmax(data_tgt_DSFE[i], dim=1) 311 | )) 312 | 313 | # domain specific 
classifier and cls_loss 314 | DSC_name = 'self.cls_fc_DSC' + str(mark) 315 | pred_src = eval(DSC_name)(data_src_DSFE) 316 | cls_loss = F.nll_loss(F.log_softmax(pred_src, dim=1), label_src.squeeze()) 317 | 318 | return cls_loss, mmd_loss, disc_loss 319 | 320 | else: 321 | data_CFE = self.sharedNet(data_src) 322 | pred = [] 323 | for i in range(num_src): 324 | DSFE_name = 'self.DSFE' + str(i) 325 | DSC_name = 'self.cls_fc_DSC' + str(i) 326 | feature_DSFE_i = eval(DSFE_name)(data_CFE) 327 | pred.append(eval(DSC_name)(feature_DSFE_i)) 328 | 329 | return pred 330 | -------------------------------------------------------------------------------- /SHOT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.utils.data as Data 6 | import os.path as osp 7 | 8 | from scipy.spatial.distance import cdist 9 | from torch.utils.data import TensorDataset, DataLoader 10 | 11 | from utils import network, loss 12 | from utils.CsvRecord import CsvRecord 13 | from utils.LogRecord import LogRecord 14 | from utils.dataloader import read_mi_test, read_seed_test 15 | from utils.utils import lr_scheduler, fix_random_seed, op_copy, cal_acc, cal_auc_roc 16 | from models.EEGNet import EEGNet, EEGNet_features, EEGNet_classifier 17 | from models.ShallowConvNet import ShallowConvNet, ShallowConvNetFeatures, ShallowConvNetClassifier, ShallowConvNetFeaturesReduced, ShallowConvNetReduced 18 | from EEG_cross_subject_loader import EEG_loader 19 | 20 | import argparse 21 | import time 22 | import os 23 | import random 24 | 25 | 26 | def data_load(X, y, args): 27 | dset_loaders = {} 28 | train_bs = args.batch_size 29 | 30 | sample_idx = torch.from_numpy(np.arange(len(y))).long() 31 | data_tar = Data.TensorDataset(X, y, sample_idx) 32 | 33 | dset_loaders["target"] = Data.DataLoader(data_tar, batch_size=train_bs, shuffle=True) 34 | dset_loaders["Target"] = 
Data.DataLoader(data_tar, batch_size=train_bs * 3, shuffle=False) 35 | return dset_loaders 36 | 37 | 38 | def train_target(args): 39 | data = EEG_loader(test_subj=args.test_subj, dataset=args.dataset) 40 | 41 | test_x, test_y = data.test_x, data.test_y 42 | X_tar, y_tar = torch.from_numpy(test_x).unsqueeze_(3).to( 43 | torch.float32), torch.squeeze( 44 | torch.from_numpy(test_y), 1).to(torch.long) 45 | dset_loaders = data_load(X_tar, y_tar, args) 46 | 47 | EEGNet_fc_num = {'MI1': 256, 'MI2': 384, 'ERP1': 32, 'ERP2': 256} 48 | ShallowConvNet_fc_num = {'MI1': 16640, 'MI2': 26520, 'ERP1': 6760, 'ERP2': 17160} 49 | 50 | if args.mode == 'baseline': 51 | if args.backbone == 'EEGNet': 52 | #model = EEGNet(chn, EEGNet_fc_num[args.dataset], class_num).to(device) 53 | netF = EEGNet_features(chn, EEGNet_fc_num[args.dataset], class_num).to(device) 54 | netC = EEGNet_classifier(chn, EEGNet_fc_num[args.dataset], class_num).to(device) 55 | elif args.backbone == 'ShallowConvNet': 56 | if args.dataset == 'ERP1': 57 | #model = ShallowConvNetReduced(class_num, chn, ShallowConvNet_fc_num[args.dataset]).to(device) 58 | netF = ShallowConvNetFeaturesReduced(class_num, chn, ShallowConvNet_fc_num[args.dataset]).to(device) 59 | else: 60 | #model = ShallowConvNet(class_num, chn, ShallowConvNet_fc_num[args.dataset]).to(device) 61 | netF = ShallowConvNetFeatures(class_num, chn, ShallowConvNet_fc_num[args.dataset]).to(device) 62 | netC = ShallowConvNetClassifier(class_num, chn, ShallowConvNet_fc_num[args.dataset]).to(device) 63 | model = torch.load(args.path).to(device) 64 | state_dict = model.state_dict() 65 | netF_state_dict = {} 66 | netC_state_dict = {} 67 | for key in state_dict: 68 | if key.startswith('classifier_block') or key.startswith('fc'): 69 | netC_state_dict[key] = state_dict[key] 70 | else: 71 | netF_state_dict[key] = state_dict[key] 72 | 73 | netF.load_state_dict(netF_state_dict) 74 | netC.load_state_dict(netC_state_dict) 75 | 76 | netC.eval() 77 | 78 | for k, v in 
netC.named_parameters(): 79 | v.requires_grad = False 80 | 81 | param_group = [] 82 | for k, v in netF.named_parameters(): 83 | if args.lr_decay1 > 0: 84 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 85 | else: 86 | v.requires_grad = False 87 | 88 | optimizer = optim.SGD(param_group) 89 | optimizer = op_copy(optimizer) 90 | 91 | max_iter = args.max_epoch * len(dset_loaders["target"]) # epoch * batch_number 92 | interval_iter = max_iter // args.interval 93 | iter_num = 0 94 | 95 | iter_test = None 96 | while iter_num < max_iter: 97 | try: 98 | inputs_test, _, tar_idx = iter_test.next() 99 | except: 100 | iter_test = iter(dset_loaders["target"]) 101 | inputs_test, _, tar_idx = iter_test.next() 102 | 103 | if inputs_test.size(0) == 1: 104 | continue 105 | 106 | inputs_test = inputs_test.to(args.device) 107 | if iter_num % interval_iter == 0 and args.cls_par > 0: 108 | netF.eval() 109 | mem_label = obtain_label(dset_loaders["Target"], netF, netC, args) 110 | mem_label = torch.from_numpy(mem_label).to(args.device) 111 | netF.train() 112 | 113 | iter_num += 1 114 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 115 | features_test = netF(inputs_test) 116 | outputs_test = netC(features_test) 117 | 118 | # # loss definition 119 | if args.cls_par > 0: 120 | pred = mem_label[tar_idx].long() 121 | # class_weight = torch.tensor([1, 2.42], dtype=torch.float32).cuda() 122 | # criterion = torch.nn.CrossEntropyLoss(weight=class_weight) 123 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 124 | classifier_loss *= args.cls_par 125 | else: 126 | classifier_loss = torch.tensor(0.0).to(args.device) 127 | 128 | if args.ent: 129 | softmax_out = nn.Softmax(dim=1)(outputs_test) 130 | # criterion = torch.nn.CrossEntropyLoss() 131 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 132 | if args.gent: 133 | msoftmax = softmax_out.mean(dim=0) 134 | gentropy_loss = torch.sum(msoftmax * torch.log(msoftmax + args.epsilon)) 135 | 
print('entropy_loss:', round(entropy_loss.item(), 3), 'gentropy_loss:', round(gentropy_loss.item(), 3)) 136 | entropy_loss += gentropy_loss 137 | im_loss = entropy_loss * args.ent_par 138 | print('classifier_loss:', round(classifier_loss.item(), 3), 'im_loss', round(im_loss.item(), 3)) 139 | classifier_loss += im_loss 140 | print('loss', round(classifier_loss.item(), 3)) 141 | print('#' * 20) 142 | optimizer.zero_grad() 143 | classifier_loss.backward() 144 | optimizer.step() 145 | 146 | if iter_num % interval_iter == 0 or iter_num == max_iter: 147 | netF.eval() 148 | if args.dataset == 'MI1' or args.dataset == 'MI2': 149 | acc_t_te = cal_acc(dset_loaders["Target"], netF, netC, args.device) 150 | else: 151 | acc_t_te = cal_auc_roc(dset_loaders["Target"], netF, netC, args.device) 152 | #log_str = 'Task: {}, Iter:{}/{}; Acc = {:.2f}%'.format(args.task_str, iter_num, max_iter, acc_t_te) 153 | #print(log_str) 154 | netF.train() 155 | 156 | if iter_num == max_iter: 157 | print('TL Score = {:.2f}%'.format(acc_t_te)) 158 | return acc_t_te 159 | 160 | 161 | def obtain_label(loader, netF, netC, args): 162 | start_test = True 163 | with torch.no_grad(): 164 | iter_test = iter(loader) 165 | for _ in range(len(loader)): 166 | data = iter_test.next() 167 | inputs = data[0] 168 | #labels = data[1] 169 | inputs = inputs.to(args.device) 170 | feas = netF(inputs) 171 | outputs = netC(feas) 172 | if start_test: 173 | all_fea = feas.float().cpu() 174 | all_output = outputs.float().cpu() 175 | #all_label = labels.float() 176 | start_test = False 177 | else: 178 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 179 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 180 | #all_label = torch.cat((all_label, labels.float()), 0) 181 | 182 | all_output = nn.Softmax(dim=1)(all_output) 183 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 184 | unknown_weight = 1 - ent / np.log(args.class_num) 185 | _, predict = torch.max(all_output, 1) 186 | 187 
| #accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 188 | if args.distance == 'cosine': 189 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 190 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 191 | 192 | all_fea = all_fea.float().cpu().numpy() 193 | K = all_output.size(1) 194 | aff = all_output.float().cpu().numpy() 195 | initc = aff.transpose().dot(all_fea) 196 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 197 | cls_count = np.eye(K)[predict].sum(axis=0) 198 | labelset = np.where(cls_count > args.threshold) 199 | labelset = labelset[0] 200 | 201 | dd = cdist(all_fea, initc[labelset], args.distance) 202 | pred_label = dd.argmin(axis=1) 203 | pred_label = labelset[pred_label] 204 | 205 | for round in range(1): # SSL 206 | aff = np.eye(K)[pred_label] 207 | initc = aff.transpose().dot(all_fea) 208 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 209 | dd = cdist(all_fea, initc[labelset], args.distance) 210 | pred_label = dd.argmin(axis=1) 211 | pred_label = labelset[pred_label] 212 | 213 | #acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 214 | 215 | return pred_label.astype('int') 216 | 217 | 218 | if __name__ == '__main__': 219 | 220 | mode = 'baseline' 221 | # mode = 'MDMAML' 222 | #for model_name in ['EEGNet', 'ShallowConvNet']: 223 | for model_name in ['EEGNet']: 224 | #for dataset in ['MI1', 'MI2', 'ERP1', 'ERP2']: 225 | for dataset in ['MI1', 'MI2']: 226 | print('SHOT', model_name, dataset) 227 | 228 | if dataset == 'MI1': 229 | subj_num = 9 230 | elif dataset == 'MI2': 231 | subj_num = 14 232 | elif dataset == 'ERP1': 233 | subj_num = 10 234 | elif dataset == 'ERP2': 235 | subj_num = 16 236 | 237 | if dataset == 'MI1': 238 | ways = 4 239 | else: 240 | ways = 2 241 | 242 | total_acc_arr = [] 243 | total_std_arr = [] 244 | 245 | for test_subj in range(0, subj_num): 246 | sub_acc_all = [] 247 | for se in range(0, 10): 248 | print('Test Subject', 
test_subj, 'Seed', se) 249 | 250 | if dataset == 'MI1': 251 | chn, class_num, trial_num = 22, 4, 576 252 | if dataset == 'MI2': 253 | chn, class_num, trial_num = 15, 2, 100 254 | if dataset == 'ERP1': 255 | chn, class_num, trial_num = 16, 2, 575 256 | if dataset == 'ERP2': 257 | chn, class_num, trial_num = 56, 2, 340 258 | 259 | args = argparse.Namespace(lr=0.01, lr_decay1=0.1, lr_decay2=1.0, ent=True, 260 | gent=True, cls_par=0.3, ent_par=1.0, epsilon=1e-05, layer='wn', 261 | interval=15, 262 | chn=chn, class_num=class_num, cov_type='oas', trial=trial_num, 263 | threshold=0, distance='cosine') 264 | 265 | args.seed = 42 266 | random.seed(args.seed) 267 | np.random.seed(args.seed) 268 | torch.manual_seed(args.seed) 269 | 270 | device = torch.device('cpu') 271 | args.cuda = True 272 | if args.cuda: 273 | torch.cuda.manual_seed(args.seed) 274 | torch.backends.cudnn.deterministic = True 275 | device = torch.device('cuda:0') 276 | 277 | args.test_subj = test_subj 278 | args.pretrain_seed = se 279 | args.backbone = model_name 280 | args.dataset = dataset 281 | args.method = 'shot' 282 | args.device = device 283 | args.batch_size = 8 284 | args.max_epoch = 10 285 | args.output_src = './runs/' + args.dataset + '/' 286 | args.mode = mode 287 | 288 | if args.mode == 'baseline': 289 | path = './runs/' + str(args.dataset) + '/baseline_' + str(args.backbone) + str( 290 | args.dataset) + '_seed' + str( 291 | args.pretrain_seed) + '_test_subj_' + str(args.test_subj) + '_epoch100.pt' 292 | args.path = path 293 | elif args.mode == 'MDMAML': 294 | path1 = './runs/' + str(args.dataset) + '/mdmaml_model1_' + str( 295 | args.dataset) + '_test_subj_' + \ 296 | str(args.test_subj) + dataset + '_test_subj_' + str(test_subj) + \ 297 | '_shots_1_meta_lr_0.001_fast_lr_0.001_meta_batch_size_1_adaptation_steps_1' \ 298 | + str(model_name) + '_num_iterations_500seed' + str(se) + '.pt' 299 | path2 = './runs/' + str(args.dataset) + '/mdmaml_model2_' + str( 300 | args.dataset) + '_test_subj_' + \ 
301 | str(args.test_subj) + dataset + '_test_subj_' + str(test_subj) + \ 302 | '_shots_1_meta_lr_0.001_fast_lr_0.001_meta_batch_size_1_adaptation_steps_1' \ 303 | + str(model_name) + '_num_iterations_500seed' + str(se) + '.pt' 304 | args.paths = [path1, path2] 305 | 306 | #source_str = 'Except_S' + str(test_subj) 307 | #target_str = 'S' + str(test_subj) 308 | #info_str = '\n========================== Transfer to ' + target_str + ' ==========================' 309 | #print(info_str) 310 | 311 | #args.task_str = source_str + '_' + target_str 312 | #args.output_dir_src = osp.join(args.output_src, source_str) 313 | 314 | sub_acc_all.append(train_target(args)) 315 | 316 | print('Sub acc: ', np.round(sub_acc_all, 5)) 317 | print('Avg acc: ', np.round(np.mean(sub_acc_all), 5)) 318 | total_acc_arr.append(np.round(np.mean(sub_acc_all), 5)) 319 | total_std_arr.append(np.round(np.std(sub_acc_all), 5)) 320 | 321 | print(str(dataset) + ' SHOT') 322 | print('avg_arr: ', total_acc_arr) 323 | print('std_arr: ', total_std_arr) 324 | print('total_avg: ', np.round(np.mean(total_acc_arr), 5)) 325 | print('total_std: ', np.round(np.std(total_acc_arr), 5)) 326 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch as tr 4 | import torch.nn as nn 5 | import math 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | from typing import Optional, Sequence 9 | 10 | 11 | def Entropy(input_): 12 | epsilon = 1e-5 13 | entropy = -input_ * tr.log(input_ + epsilon) 14 | entropy = tr.sum(entropy, dim=1) 15 | return entropy 16 | 17 | 18 | class CELabelSmooth(nn.Module): 19 | """Cross entropy loss with label smoothing regularizer. 20 | Reference: 21 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 22 | Equation: y = (1 - epsilon) * y + epsilon / K. 
23 | Args: 24 | num_classes (int): number of classes. 25 | epsilon (float): weight. 26 | """ 27 | 28 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 29 | super(CELabelSmooth, self).__init__() 30 | self.num_classes = num_classes 31 | self.epsilon = epsilon 32 | self.use_gpu = use_gpu 33 | self.logsoftmax = nn.LogSoftmax(dim=1) 34 | self.reduction = reduction 35 | 36 | def forward(self, inputs, targets): 37 | """ 38 | Args: 39 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 40 | targets: ground truth labels with shape (num_classes) 41 | """ 42 | log_probs = self.logsoftmax(inputs) 43 | 44 | # 加入mixup之后,原始标签已经是one hot的形式,这里不需要再变换 45 | # targets = tr.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 46 | if self.use_gpu: targets = targets.cuda() 47 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 48 | loss = (- targets * log_probs).sum(dim=1) 49 | if self.reduction: 50 | return loss.mean() 51 | else: 52 | return loss 53 | 54 | 55 | class CELabelSmooth_raw(nn.Module): 56 | """Cross entropy loss with label smoothing regularizer. 57 | Reference: 58 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 59 | Equation: y = (1 - epsilon) * y + epsilon / K. 60 | Args: 61 | num_classes (int): number of classes. 62 | epsilon (float): weight. 
63 | """ 64 | 65 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 66 | super(CELabelSmooth_raw, self).__init__() 67 | self.num_classes = num_classes 68 | self.epsilon = epsilon 69 | self.use_gpu = use_gpu 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | self.reduction = reduction 72 | 73 | def forward(self, inputs, targets): 74 | """ 75 | Args: 76 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 77 | targets: ground truth labels with shape (num_classes) 78 | """ 79 | log_probs = self.logsoftmax(inputs) 80 | targets = tr.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 81 | if self.use_gpu: targets = targets.cuda() 82 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 83 | loss = (- targets * log_probs).sum(dim=1) 84 | if self.reduction: 85 | return loss.mean() 86 | else: 87 | return loss 88 | 89 | 90 | class KnowledgeDistillationLoss(nn.Module): 91 | def __init__(self, reduction='mean', alpha=-1.): 92 | super().__init__() 93 | self.reduction = reduction 94 | self.alpha = alpha 95 | 96 | def forward(self, inputs, targets, mask=None): 97 | inputs = inputs.narrow(1, 0, targets.shape[1]) 98 | outputs = tr.log_softmax(inputs, dim=1) 99 | labels = tr.softmax(targets * self.alpha, dim=1) 100 | 101 | loss = (outputs * labels).mean(dim=1) 102 | if mask is not None: 103 | loss = loss * mask.float() 104 | 105 | if self.reduction == 'mean': 106 | outputs = -tr.mean(loss) 107 | elif self.reduction == 'sum': 108 | outputs = -tr.sum(loss) 109 | else: 110 | outputs = -loss 111 | 112 | return outputs 113 | 114 | 115 | class ConsistencyLoss(nn.Module): 116 | """ 117 | Label consistency loss. 118 | """ 119 | 120 | def __init__(self, num_select=2): 121 | super(ConsistencyLoss, self).__init__() 122 | self.num_select = num_select 123 | 124 | def forward(self, prob): 125 | dl = 0. 
126 | count = 0 127 | for i in range(prob.shape[1] - 1): 128 | for j in range(i + 1, prob.shape[1]): 129 | dl += self.jensen_shanon(prob[:, i, :], prob[:, j, :], dim=1) 130 | count += 1 131 | return dl / count 132 | 133 | @staticmethod 134 | def jensen_shanon(pred1, pred2, dim): 135 | """ 136 | Jensen-Shannon Divergence. 137 | """ 138 | m = (tr.softmax(pred1, dim=dim) + tr.softmax(pred2, dim=dim)) / 2 139 | pred1 = F.log_softmax(pred1, dim=dim) 140 | pred2 = F.log_softmax(pred2, dim=dim) 141 | return (F.kl_div(pred1, m.detach(), reduction='batchmean') + F.kl_div(pred2, m.detach(), 142 | reduction='batchmean')) / 2 143 | 144 | 145 | class source_inconsistency_loss(nn.Module): 146 | # source models inconsistency loss. 147 | def __init__(self, th_max=0.1): 148 | super(source_inconsistency_loss, self).__init__() 149 | self.th_max = th_max 150 | 151 | # 计算不同models在每个样本预测概率的每个类别上的方差,要求 152 | def forward(self, prob): # [4, 8, 4] 153 | si_std = tr.std(prob, dim=1).mean(dim=1).mean(dim=0) 154 | return si_std 155 | 156 | 157 | class BatchEntropyLoss(nn.Module): 158 | """ 159 | Batch-entropy loss. 160 | 要求各源模型预测的各类别平均entropy的平均值要小, 161 | 而DECISION里面是加权之后的标签上各类别平均的entropy要小 162 | 差异就是计算entropy先还是算类别均值先 163 | """ 164 | 165 | def __init__(self): 166 | super(BatchEntropyLoss, self).__init__() 167 | 168 | def forward(self, prob): # prob: [4, 8, 4] 169 | batch_entropy = F.softmax(prob, dim=2).mean(dim=0) 170 | batch_entropy = batch_entropy * (-batch_entropy.log()) 171 | batch_entropy = -batch_entropy.sum(dim=1) 172 | loss = batch_entropy.mean() 173 | return loss, batch_entropy 174 | 175 | 176 | class InstanceEntropyLoss(nn.Module): 177 | """ 178 | Instance-entropy loss. 
179 | """ 180 | 181 | def __init__(self): 182 | super(InstanceEntropyLoss, self).__init__() 183 | 184 | def forward(self, prob): 185 | instance_entropy = F.softmax(prob, dim=2) * F.log_softmax(prob, dim=2) 186 | instance_entropy = -1.0 * instance_entropy.sum(dim=2) 187 | instance_entropy = instance_entropy.mean(dim=0) 188 | loss = instance_entropy.mean() 189 | return loss, instance_entropy 190 | 191 | 192 | class InformationMaximizationLoss(nn.Module): 193 | """ 194 | Information maximization loss. 195 | """ 196 | 197 | def __init__(self): 198 | super(InformationMaximizationLoss, self).__init__() 199 | 200 | def forward(self, pred_prob, epsilon): 201 | softmax_out = nn.Softmax(dim=1)(pred_prob) 202 | ins_entropy_loss = tr.mean(Entropy(softmax_out)) 203 | msoftmax = softmax_out.mean(dim=0) 204 | class_entropy_loss = tr.sum(-msoftmax * tr.log(msoftmax + epsilon)) 205 | im_loss = ins_entropy_loss - class_entropy_loss 206 | 207 | return im_loss 208 | 209 | 210 | # =============================================================DAN Function============================================= 211 | class MultipleKernelMaximumMeanDiscrepancy(nn.Module): 212 | r""" 213 | Args: 214 | kernels (tuple(tr.nn.Module)): kernel functions. 215 | linear (bool): whether use the linear version of DAN. 
Default: False 216 | 217 | Inputs: 218 | - z_s (tensor): activations from the source domain, :math:`z^s` 219 | - z_t (tensor): activations from the target domain, :math:`z^t` 220 | """ 221 | 222 | def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False): 223 | super(MultipleKernelMaximumMeanDiscrepancy, self).__init__() 224 | self.kernels = kernels 225 | self.index_matrix = None 226 | self.linear = linear 227 | 228 | def forward(self, z_s: tr.Tensor, z_t: tr.Tensor) -> tr.Tensor: 229 | features = tr.cat([z_s, z_t], dim=0) 230 | batch_size = int(z_s.size(0)) 231 | self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device) 232 | 233 | kernel_matrix = sum([kernel(features) for kernel in self.kernels]) # Add up the matrix of each kernel 234 | # Add 2 / (n-1) to make up for the value on the diagonal 235 | # to ensure loss is positive in the non-linear version 236 | loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1) 237 | 238 | return loss 239 | 240 | 241 | def _update_index_matrix(batch_size: int, index_matrix: Optional[tr.Tensor] = None, 242 | linear: Optional[bool] = True) -> tr.Tensor: 243 | r""" 244 | Update the `index_matrix` which convert `kernel_matrix` to loss. 245 | If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`. 246 | Else return a new tensor with shape (2 x batch_size, 2 x batch_size). 247 | """ 248 | if index_matrix is None or index_matrix.size(0) != batch_size * 2: 249 | index_matrix = tr.zeros(2 * batch_size, 2 * batch_size) 250 | if linear: 251 | for i in range(batch_size): 252 | s1, s2 = i, (i + 1) % batch_size 253 | t1, t2 = s1 + batch_size, s2 + batch_size 254 | index_matrix[s1, s2] = 1. / float(batch_size) 255 | index_matrix[t1, t2] = 1. / float(batch_size) 256 | index_matrix[s1, t2] = -1. / float(batch_size) 257 | index_matrix[s2, t1] = -1. 
/ float(batch_size) 258 | else: 259 | for i in range(batch_size): 260 | for j in range(batch_size): 261 | if i != j: 262 | index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1)) 263 | index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1)) 264 | for i in range(batch_size): 265 | for j in range(batch_size): 266 | index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size) 267 | index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size) 268 | return index_matrix 269 | 270 | 271 | class GaussianKernel(nn.Module): 272 | r"""Gaussian Kernel Matrix 273 | Args: 274 | sigma (float, optional): bandwidth :math:`\sigma`. Default: None 275 | track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`. 276 | Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True`` 277 | alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True`` 278 | 279 | Inputs: 280 | - X (tensor): input group :math:`X` 281 | 282 | Shape: 283 | - Inputs: :math:`(minibatch, F)` where F means the dimension of input features. 
284 | - Outputs: :math:`(minibatch, minibatch)` 285 | """ 286 | 287 | def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True, 288 | alpha: Optional[float] = 1.): 289 | super(GaussianKernel, self).__init__() 290 | assert track_running_stats or sigma is not None 291 | self.sigma_square = tr.tensor(sigma * sigma) if sigma is not None else None 292 | self.track_running_stats = track_running_stats 293 | self.alpha = alpha 294 | 295 | def forward(self, X: tr.Tensor) -> tr.Tensor: 296 | l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2) 297 | 298 | if self.track_running_stats: 299 | self.sigma_square = self.alpha * tr.mean(l2_distance_square.detach()) 300 | 301 | return tr.exp(-l2_distance_square / (2 * self.sigma_square)) 302 | 303 | 304 | # =============================================================CDANE Function=========================================== 305 | def CDANE(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 306 | softmax_output = input_list[1].detach() 307 | feature = input_list[0] 308 | if random_layer is None: 309 | # print('None') 310 | op_out = tr.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 311 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 312 | else: 313 | random_out = random_layer.forward([feature, softmax_output]) 314 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 315 | batch_size = softmax_output.size(0) // 2 316 | dc_target = tr.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 317 | if entropy is not None: 318 | entropy.register_hook(grl_hook(coeff)) 319 | entropy = 1.0 + tr.exp(-entropy) 320 | source_mask = tr.ones_like(entropy) 321 | source_mask[feature.size(0) // 2:] = 0 322 | source_weight = entropy * source_mask 323 | target_mask = tr.ones_like(entropy) 324 | target_mask[0:feature.size(0) // 2] = 0 325 | target_weight = entropy * target_mask 326 | weight = source_weight / 
tr.sum(source_weight).detach().item() + \ 327 | target_weight / tr.sum(target_weight).detach().item() 328 | return tr.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / tr.sum( 329 | weight).detach().item() 330 | else: 331 | return nn.BCELoss()(ad_out, dc_target) 332 | 333 | 334 | def grl_hook(coeff): 335 | def fun1(grad): 336 | return -coeff * grad.clone() 337 | 338 | return fun1 339 | 340 | 341 | class ReverseLayerF(Function): 342 | 343 | @staticmethod 344 | def forward(ctx, x, alpha=1): 345 | ctx.alpha = alpha 346 | return x.view_as(x) 347 | 348 | @staticmethod 349 | def backward(ctx, grad_output): 350 | output = grad_output.neg() * ctx.alpha 351 | return output, None 352 | 353 | 354 | class RandomLayer(nn.Module): 355 | def __init__(self, input_dim_list=[], output_dim=1024): 356 | super(RandomLayer, self).__init__() 357 | self.input_num = len(input_dim_list) 358 | self.output_dim = output_dim 359 | self.random_matrix = [tr.randn(input_dim_list[i], output_dim) for i in range(self.input_num)] 360 | 361 | def forward(self, input_list): 362 | return_list = [tr.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)] 363 | return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list)) 364 | for single in return_list[1:]: 365 | return_tensor = tr.mul(return_tensor, single) 366 | return return_tensor 367 | 368 | def cuda(self): 369 | super(RandomLayer, self).cuda() 370 | self.random_matrix = [val.cuda() for val in self.random_matrix] 371 | 372 | 373 | # =============================================================MCC Function============================================= 374 | class ClassConfusionLoss(nn.Module): 375 | """ 376 | The class confusion loss 377 | 378 | Parameters: 379 | - **t** Optional(float): the temperature factor used in MCC 380 | """ 381 | 382 | def __init__(self, t): 383 | super(ClassConfusionLoss, self).__init__() 384 | self.t = t 385 | 386 | def forward(self, output: tr.Tensor) -> 
# -*- coding: utf-8 -*-
# @Time    : 2021/12/18 11:04
# @Author  : wenzhang
# @File    : dataloader.py
import torch as tr
import numpy as np
from sklearn import preprocessing
from torch.autograd import Variable
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from imblearn.over_sampling import SMOTE
from utils.data_augment import data_aug
from scipy.io import loadmat, savemat
from os import walk


def read_mi_all(args):
    """Return per-subject tangent-space features and labels for an MI dataset.

    Raw layout: data (n_sub, n_trial, n_chan, n_time), label (n_sub, n_trial),
    e.g. (9, 288, 22, 750) / (9, 288).  Optionally augments each subject first.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    npz = np.load(file)
    raw_all, label_all = npz['data'], npz['label']

    data, label = [], []
    for sub in range(args.N):
        # one source subject at a time
        x = np.squeeze(raw_all[sub, :, :, :])
        y = label_all[sub, :].reshape(-1, 1)

        if args.aug:
            n_time = x.shape[2]
            # mult_flag, noise_flag, neg_flag, freq_mod_flag
            aug_flags = [True, True, True, True]
            x = np.transpose(x, (0, 2, 1))
            x, y = data_aug(x, y, n_time, aug_flags)
            x = np.transpose(x, (0, 2, 1))

        # trials -> SPD covariance matrices -> tangent-space feature vectors
        cov = Covariances(estimator=args.cov_type).transform(x)
        fea = TangentSpace().fit_transform(cov)

        data.append(fea)
        label.append(y.reshape(-1, 1))

    return data, label
def read_mi_train(args):
    """Tangent-space features and labels for one source subject (``args.ids``).

    Raw layout: data (n_sub, n_trial, n_chan, n_time), label (n_sub, n_trial),
    e.g. (9, 288, 22, 750) / (9, 288).  Applies data augmentation when
    ``args.aug`` is set.  Returns (features, labels) as torch tensors.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']

    # source sub
    src_data = np.squeeze(Data_raw[args.ids, :, :, :])
    src_label = np.squeeze(Label[args.ids, :])
    src_label = tr.from_numpy(src_label).long()
    print(src_data.shape, src_label.shape)  # (288, 22, 750)

    if args.aug:
        sample_size = src_data.shape[2]
        # mult_flag, noise_flag, neg_flag, freq_mod_flag
        flag_aug = [True, True, True, True]
        # NOTE(review): src_label is already a torch tensor here; data_aug is
        # assumed to accept it and to return numpy arrays — confirm.
        src_data = np.transpose(src_data, (0, 2, 1))
        src_data, src_label = data_aug(src_data, src_label, sample_size, flag_aug)
        src_data = np.transpose(src_data, (0, 2, 1))
        src_label = tr.from_numpy(src_label).long()

    # trials -> SPD covariances -> tangent-space features
    covar = Covariances(estimator=args.cov_type).transform(src_data)
    fea_tsm = TangentSpace().fit_transform(covar)
    fea_tsm = Variable(tr.from_numpy(fea_tsm).float())

    # X.shape - (#samples, # feas)
    print(fea_tsm.shape, src_label.shape)

    return fea_tsm, src_label


def read_mi_test(args):
    """Tangent-space features and labels for the target subject (``args.idt``).

    Returns (features, labels) as torch tensors; no augmentation is applied.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']

    # target sub
    tar_data = np.squeeze(Data_raw[args.idt, :, :, :])
    tar_label = np.squeeze(Label[args.idt, :])
    tar_label = tr.from_numpy(tar_label).long()

    # 288 * 22 * 750 trials -> SPD covariances -> tangent-space features
    covar_src = Covariances(estimator=args.cov_type).transform(tar_data)
    fea_tsm = TangentSpace().fit_transform(covar_src)
    fea_tsm = Variable(tr.from_numpy(fea_tsm).float())

    # X.shape - (#samples, # feas)
    print(fea_tsm.shape, tar_label.shape)
    return fea_tsm, tar_label


def read_mi_test_aug(args):
    """Target-subject features plus an augmented copy.

    Returns (X_tar, y_tar, X_tar_aug, y_tar_aug): the plain tangent-space
    features/labels of subject ``args.idt`` and an augmented version built
    with all four augmentation flags enabled.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']

    # target sub
    tar_data = np.squeeze(Data_raw[args.idt, :, :, :])
    tar_label = np.squeeze(Label[args.idt, :])

    # 288 * 22 * 750
    covar_tar = Covariances(estimator=args.cov_type).transform(tar_data)
    X_tar = TangentSpace().fit_transform(covar_tar)
    X_tar = Variable(tr.from_numpy(X_tar).float())
    y_tar = tr.from_numpy(tar_label).long()

    sample_size = tar_data.shape[2]
    flag_aug = [True, True, True, True]
    tar_data_tmp = np.transpose(tar_data, (0, 2, 1))
    tar_data_tmp, tar_label_aug = data_aug(tar_data_tmp, tar_label, sample_size, flag_aug)
    tar_data_aug = np.transpose(tar_data_tmp, (0, 2, 1))

    # 288 * 22 * 750
    covar_tar = Covariances(estimator=args.cov_type).transform(tar_data_aug)
    X_tar_aug = TangentSpace().fit_transform(covar_tar)
    X_tar_aug = Variable(tr.from_numpy(X_tar_aug).float())
    y_tar_aug = tr.from_numpy(tar_label_aug).long()

    # X.shape - (#samples, # feas)
    # FIX: originally printed y_tar.shape twice instead of X_tar.shape.
    print(X_tar.shape, y_tar.shape)
    print(X_tar_aug.shape, y_tar_aug.shape)
    return X_tar, y_tar, X_tar_aug, y_tar_aug
def read_mi_combine(args):  # no data augment
    """Pool every MI subject except ``args.idt`` into one source set.

    Raw layout: data (9, 288, 22, 750), label (9, 288).  Returns the pooled
    tangent-space features and labels as torch tensors.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']
    # print('raw data shape', Data_raw.shape, Label.shape)

    Data_new = Data_raw.copy()
    n_sub = len(Data_raw)

    # MTS transfer: keep every subject index except the target
    ids = np.delete(np.arange(0, n_sub), args.idt)
    src_data = np.concatenate([np.squeeze(Data_new[i]) for i in ids], axis=0)
    src_label = np.concatenate([np.squeeze(Label[i]) for i in ids], axis=0)

    # final label
    src_label = tr.from_numpy(np.squeeze(src_label)).long()
    print(src_data.shape, src_label.shape)

    # final features: trials -> SPD covariances -> tangent space
    covar = Covariances(estimator=args.cov_type).transform(src_data)
    fea_tsm = TangentSpace().fit_transform(covar)
    src_data = Variable(tr.from_numpy(fea_tsm).float())

    return src_data, src_label


def read_mi_combine_tar(args):  # no data augment
    """Pool all non-target MI subjects as source and load the target subject.

    Returns (src_data, src_label, tar_data, tar_label) as torch tensors, with
    features mapped to the Riemannian tangent space.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']
    print('raw data shape', Data_raw.shape, Label.shape)

    Data_new = Data_raw.copy()
    n_sub = len(Data_raw)

    # combine multiple source subjects (all but the target)
    ids = np.delete(np.arange(0, n_sub), args.idt)
    src_data = np.concatenate([np.squeeze(Data_new[i]) for i in ids], axis=0)
    src_label = np.concatenate([np.squeeze(Label[i]) for i in ids], axis=0)

    # final label
    src_label = tr.from_numpy(np.squeeze(src_label)).long()
    print(src_data.shape, src_label.shape)  # (n_src, chns, fts)
    covar = Covariances(estimator=args.cov_type).transform(src_data)
    fea_tsm = TangentSpace().fit_transform(covar)  # tangent space transform
    src_data = Variable(tr.from_numpy(fea_tsm).float())  # (n_src, low_dim_fts)

    # single target domain data
    tar_data = np.squeeze(Data_new[args.idt, :, :, :])
    tar_label = np.squeeze(Label[args.idt, :])
    print(tar_data.shape, tar_label.shape)  # (n_tar, chns, fts)
    covar = Covariances(estimator=args.cov_type).transform(tar_data)
    tmp_ref = TangentSpace().fit(covar)  # tangent space transform
    fea_tsm = tmp_ref.transform(covar)
    tar_data = Variable(tr.from_numpy(fea_tsm).float())
    tar_label = tr.from_numpy(tar_label).long()
    print(src_data.shape, src_label.shape)  # (n_src, low_dim_fts)
    print(tar_data.shape, tar_label.shape)  # (n_tar, low_dim_fts)

    return src_data, src_label, tar_data, tar_label


def read_seizure_combine_tar(args):  # no data augment
    """Pool all seizure-domain .mat files except index ``args.idt`` as source.

    SMOTE balances the pooled source classes; both domains are normalized via
    ``data_normalize``.  Returns (src_data, src_label, tar_data, tar_label).
    """
    if args.data_env == 'local':
        # NOTE(review): 'file' is assigned but never used, and 'domains' is
        # only defined in the 'gpu' branch — the local path would raise
        # NameError below.  Preserved as-is; confirm intended local layout.
        file = '../data/fts_labels/' + args.data + '.npz'
    if args.data_env == 'gpu':
        domains = next(walk('/home/zwwang/code/Source_combined/data/fts_labels/'), (None, None, []))[2]
    i = args.idt
    src_x, src_y = [], []
    for j in range(len(domains)):
        if i == j:
            continue
        src = loadmat('/home/zwwang/code/Source_combined/data/fts_labels/' + domains[j])
        src_x.append(src['data'])
        src_y.append(src['label'])
    src_data = np.concatenate(src_x, axis=0)
    src_label = np.concatenate(src_y, axis=1).squeeze()
    # print(src_data.shape, src_label.shape)  # (n_src, chns, fts)

    # load target domain
    tar = loadmat('/home/zwwang/code/Source_combined/data/fts_labels/' + domains[i])
    tar_data, tar_label = tar['data'], tar['label']
    tar_label = tar_label.squeeze()
    # print(tar_data.shape, tar_label.shape)  # (n_tar, chns, fts)

    # smooth class imbalance in the pooled source set
    oversample = SMOTE(random_state=42)
    src_data, src_label = oversample.fit_resample(src_data, src_label)
    src_data = data_normalize(src_data, args.norm)
    src_data = Variable(tr.from_numpy(src_data).float())
    src_label = tr.from_numpy(src_label).long()
    print(src_data.shape, src_label.shape)  # (n_src, chns, fts)

    # single target domain data
    tar_data = data_normalize(tar_data, args.norm)
    tar_data = Variable(tr.from_numpy(tar_data).float())
    tar_label = tr.from_numpy(tar_label).long()
    print(tar_data.shape, tar_label.shape)  # (n_tar, chns, fts)

    return src_data, src_label, tar_data, tar_label
def data_normalize(fea_de, norm_type):
    """Normalize a feature matrix; only 'zscore' is implemented.

    Any other ``norm_type`` returns the input unchanged.
    """
    if norm_type == 'zscore':
        fea_de = preprocessing.StandardScaler().fit_transform(fea_de)
    return fea_de


def read_seed_all(args):
    """Return per-subject SEED features and labels as plain numpy arrays.

    Raw layout: data (15, 3394, 310), label (15, 3394).
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    npz = np.load(file)
    raw_all, label_all = npz['data'], npz['label']

    data, label = [], []
    for sub in range(args.N):
        # one source subject at a time
        data.append(np.squeeze(raw_all[sub, :, :]))
        label.append(label_all[sub, :].reshape(-1, 1))

    return data, label


def read_seed_train(args):
    """Normalized SEED features and labels for one source subject (``args.ids``)."""
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    npz = np.load(file)
    raw_all, label_all = npz['data'], npz['label']

    # source sub
    feats = np.squeeze(raw_all[args.ids, :, :])
    feats = data_normalize(feats, args.norm)
    feats = Variable(tr.from_numpy(feats).float())

    y = tr.from_numpy(np.squeeze(label_all[args.ids, :])).long()
    print(feats.shape, y.shape)

    return feats, y


def read_seed_test(args):
    """Normalized SEED features and labels for the target subject (``args.idt``)."""
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    npz = np.load(file)
    raw_all, label_all = npz['data'], npz['label']

    # target sub
    feats = np.squeeze(raw_all[args.idt, :, :])
    feats = data_normalize(feats, args.norm)
    feats = Variable(tr.from_numpy(feats).float())

    y = tr.from_numpy(np.squeeze(label_all[args.idt, :])).long()
    print(feats.shape, y.shape)

    return feats, y


def read_seed_combine(args):
    """Pool every SEED subject except ``args.idt`` into one normalized source set."""
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    npz = np.load(file)
    raw_all, label_all = npz['data'], npz['label']
    print(raw_all.shape, label_all.shape)

    n_sub = len(raw_all)
    ids = np.delete(np.arange(0, n_sub), args.idt)
    feats = np.concatenate([np.squeeze(raw_all[i, :, :]) for i in ids], axis=0)
    labels = np.concatenate([np.squeeze(label_all[i, :]) for i in ids], axis=0)

    feats = data_normalize(feats, args.norm)
    feats = Variable(tr.from_numpy(feats).float())
    labels = tr.from_numpy(labels).long()
    print(feats.shape, labels.shape)

    return feats, labels
def read_seed_combine_tar(args):
    """Pool all SEED subjects except ``args.idt`` as source; ``args.idt`` is target.

    Raw layout: data (15, 3394, 310), label (15, 3394).  Both domains are
    normalized via ``data_normalize``.  Returns
    (src_data, src_label, tar_data, tar_label) as torch tensors.
    """
    if args.data_env == 'local':
        file = '/Users/wenz/dataset/MOABB/' + args.data + '.npz'
    if args.data_env == 'gpu':
        file = '/mnt/ssd2/wzw/data/bci/' + args.data + '.npz'

    MI = np.load(file)
    Data_raw, Label = MI['data'], MI['label']

    n_sub = len(Data_raw)
    ids = np.delete(np.arange(0, n_sub), args.idt)
    src_data, src_label = [], []
    for i in range(n_sub - 1):
        src_data.append(np.squeeze(Data_raw[ids[i], :, :]))
        src_label.append(np.squeeze(Label[ids[i], :]))
    src_data = np.concatenate(src_data, axis=0)
    src_label = np.concatenate(src_label, axis=0)

    src_data = data_normalize(src_data, args.norm)
    src_data = Variable(tr.from_numpy(src_data).float())
    src_label = tr.from_numpy(src_label).long()
    print(src_data.shape, src_label.shape)

    # target sub
    tar_data = np.squeeze(Data_raw[args.idt, :, :])
    tar_data = data_normalize(tar_data, args.norm)
    tar_data = Variable(tr.from_numpy(tar_data).float())
    tar_label = np.squeeze(Label[args.idt, :])
    tar_label = tr.from_numpy(tar_label).long()
    print(tar_data.shape, tar_label.shape)

    return src_data, src_label, tar_data, tar_label


def obtain_train_val_source(y_array, trial_ins_num, val_type):
    """Split pooled source-sample indices into train/validation index sets.

    Parameters
    ----------
    y_array : torch.Tensor
        Source labels; only its length is used.
    trial_ins_num : int
        Number of consecutive samples per subject/trial block (used by 'last').
    val_type : {'random', 'last'}
        Splitting strategy; 90% train / 10% validation either way.

    Returns
    -------
    (id_train, id_val)
        Index collections for the train and validation subsets.

    Raises
    ------
    ValueError
        If ``val_type`` is not one of the supported strategies (previously
        this fell through to an ``UnboundLocalError``).
    """
    y_array = y_array.numpy()
    ins_num_all = len(y_array)
    src_idx = range(ins_num_all)

    if val_type == 'random':
        # Random shuffling tends to inflate results on both the MI and SEED
        # datasets (leakage across trial blocks).
        # NOTE(review): random_split returns Subset objects here, unlike the
        # numpy arrays of the 'last' branch — callers appear to tolerate both.
        num_train = int(0.9 * len(src_idx))
        id_train, id_val = tr.utils.data.random_split(src_idx, [num_train, len(src_idx) - num_train])
    elif val_type == 'last':
        # Sequential split per block: generally fine, but would be biased if a
        # block's samples were ordered by class.
        num_train = int(0.9 * trial_ins_num)
        id_train = np.array(src_idx).reshape(-1, trial_ins_num)[:, :num_train].reshape(1, -1).flatten()
        id_val = np.array(src_idx).reshape(-1, trial_ins_num)[:, num_train:].reshape(1, -1).flatten()
    else:
        raise ValueError("val_type must be 'random' or 'last', got %r" % (val_type,))

    return id_train, id_val