├── README.md
├── adv_layer.py
├── main.py
└── models.py

/README.md:
--------------------------------------------------------------------------------
# GraphEmotionNet
A PyTorch implementation of our paper "Adaptive Spatial-Temporal Aware Graph Learning for EEG-based Emotion Recognition".
You can find the corresponding article here: [GraphEmotionNet](https://spj.science.org/doi/pdf/10.34133/cbsystems.0088)
# Dataset
Prepare the datasets: [SEED](https://bcmi.sjtu.edu.cn/home/seed/seed.html), [SEED-IV](https://bcmi.sjtu.edu.cn/home/seed/seed-iv.html) and [MDD](https://figshare.com/articles/dataset/EEG_Data_New/4244171/2)
# Training
The model definition is in the file models.py.

The code for the domain adaptation part is in the file adv_layer.py.

You can start training the model by running the main.py file.
# Citation
If you find our work helpful to your research, please consider citing our paper in your publications.

    @article{yeadaptive,
      title={Adaptive Spatial-Temporal Aware Graph Learning for EEG-based Emotion Recognition},
      author={Ye, Weishan and Wang, Jiyuan and Chen, Lin and Dai, Lifei and Sun, Zhe and Liang, Zhen},
      journal={Cyborg and Bionic Systems},
      publisher={AAAS}
    }

--------------------------------------------------------------------------------
/adv_layer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 14 11:15:46 2021

@author: mindlab
"""
import torch.nn as nn
import torch
from torch.autograd import Function
import torch.nn.functional as F
from typing import Optional, Any, Tuple
import numpy as np

class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass, negates (and scales
    by alpha) the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None

class GradientReverseFunction(Function):

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None
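# Editor's sketch (not part of the original repo): a minimal check of the
# gradient-reversal behaviour -- the forward pass is the identity and the
# backward pass flips the gradient sign and scales it by alpha.
def _demo_reverse_layer():
    x = torch.ones(2, 3, requires_grad=True)
    out = ReverseLayerF.apply(x, 1.0)   # forward: identity
    out.sum().backward()                # backward: gradient is -alpha * 1
    print(x.grad)                       # expect a (2, 3) tensor of -1.0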
class WarmStartGradientReverseLayer(nn.Module):
    r"""Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start

    The forward and backward behaviours are:

    .. math::
        \mathcal{R}(x) = x,

        \dfrac{d\mathcal{R}}{dx} = -\lambda I.

    :math:`\lambda` is initialized to :math:`lo` and is gradually changed to :math:`hi` using the following schedule:

    .. math::
        \lambda = \dfrac{2(hi-lo)}{1+\exp(-\alpha \dfrac{i}{N})} - (hi-lo) + lo

    where :math:`i` is the iteration step.

    Parameters:
        - **alpha** (float, optional): :math:`\alpha`. Default: 1.0
        - **lo** (float, optional): Initial value of :math:`\lambda`. Default: 0.0
        - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0
        - **max_iters** (int, optional): :math:`N`. Default: 1000
        - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called.
          Otherwise use the function `step` to increase :math:`i`. Default: False
    """

    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
                 max_iters: Optional[int] = 1000, auto_step: Optional[bool] = False):
        super(WarmStartGradientReverseLayer, self).__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.iter_num = 0
        self.max_iters = max_iters
        self.auto_step = auto_step

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        coeff = np.float32(
            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
            - (self.hi - self.lo) + self.lo
        )
        if self.auto_step:
            self.step()
        return GradientReverseFunction.apply(input, coeff)

    def step(self):
        """Increase the iteration number :math:`i` by 1."""
        self.iter_num += 1
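# Editor's sketch (not part of the original repo): a few values of the
# warm-start schedule with lo=0, hi=1, alpha=1, N=1000. The coefficient
# starts at 0 and saturates towards hi as i grows.
def _demo_warm_start_schedule():
    for i in (0, 500, 1000, 10000):
        coeff = 2.0 / (1.0 + np.exp(-1.0 * i / 1000.0)) - 1.0
        print(i, round(float(coeff), 3))   # 0.0, 0.245, 0.462, 1.0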
def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    """Computes the accuracy (in percent) for binary classification."""
    with torch.no_grad():
        batch_size = target.size(0)
        pred = (output >= 0.5).float().t().view(-1)
        correct = pred.eq(target.view(-1)).float().sum()
        correct.mul_(100. / batch_size)
    return correct

class Discriminator(nn.Module):
    def __init__(self, hidden_1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(hidden_1, hidden_1)
        self.fc2 = nn.Linear(hidden_1, 1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        # x = F.leaky_relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

class DomainAdversarialLoss(nn.Module):
    r"""The Domain Adversarial Loss (DANN)

    Domain adversarial loss measures the domain discrepancy by training a domain discriminator.
    Given the domain discriminator :math:`D` and the feature representation :math:`f`, the DANN loss is

    .. math::
        loss(\mathcal{D}_s, \mathcal{D}_t) = \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \log[D(f_i^s)]
        + \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \log[1-D(f_j^t)].

    Parameters:
        - **hidden_1** (int): Width of the domain discriminator, which predicts the
          domains of features. Its input shape is (N, F) and output shape is (N, 1).
        - **reduction** (string, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
          ``'mean'``: the sum of the output will be divided by the number of
          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs: x
        - **x** (tensor): concatenated source and target features, source half first;
          each half has shape :math:`(N, F)` where F is the feature dimension.

    Shape:
        - Outputs: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, )`.

    Examples::
        >>> loss = DomainAdversarialLoss(hidden_1=64, reduction='mean')
        >>> # features from the source domain and the target domain
        >>> f_s, f_t = torch.randn(20, 64), torch.randn(20, 64)
        >>> output = loss(torch.cat((f_s, f_t)))
    """

    def __init__(self, hidden_1, reduction: Optional[str] = 'mean', max_iter=1000):
        super(DomainAdversarialLoss, self).__init__()
        self.grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0., hi=1., max_iters=max_iter, auto_step=True)
        self.domain_discriminator = Discriminator(hidden_1)
        self.bce = nn.BCELoss(reduction=reduction)
        self.domain_discriminator_accuracy = None

    def forward(self, x):
        f = self.grl(x)
        d = self.domain_discriminator(f)
        # the first half of the batch is assumed to come from the source domain
        source_num = int(len(x) / 2)
        d_s, d_t = d.chunk(2, dim=0)
        d_label_s = torch.ones(source_num, 1).to(x.device)
        d_label_t = torch.zeros(source_num, 1).to(x.device)
        self.domain_discriminator_accuracy = 0.5 * (binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))
        return 0.5 * (self.bce(d_s, d_label_s) + self.bce(d_t, d_label_t))
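# Editor's sketch (not part of the original repo): a smoke test of the loss
# on random features; the first half of the batch plays the source domain.
def _demo_domain_adversarial_loss():
    dann = DomainAdversarialLoss(hidden_1=64)
    f_s, f_t = torch.randn(20, 64), torch.randn(20, 64)
    loss = dann(torch.cat((f_s, f_t)))
    # an untrained discriminator gives a loss near log(2) and ~50% accuracy
    print(loss.item(), float(dann.domain_discriminator_accuracy))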
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
'''
Description:
Author: voicebeer
Date: 2020-09-14 01:01:51
LastEditTime: 2021-12-28 01:46:52
'''
# standard
import torch
import numpy as np
from tqdm import tqdm
import models
import os
import random
from torch.optim import Adam, SGD, RMSprop
from sklearn import preprocessing
import scipy.io as scio
import torch.utils.data as Data
from matplotlib import pyplot as plt
import pandas as pd


def setup_seed(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    np.random.seed(seed)  # NumPy module
    random.seed(seed)  # Python random module
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def train_GEN_contrast(subject_id, parameter, net_params, source_loaders, target_loader):
    setup_seed(20)
    device = net_params['DEVICE']
    model = models.GEN_contrastNet(net_params).to(device)
    optimizer = RMSprop(model.parameters(), lr=parameter['init_lr'], weight_decay=parameter['weight_decay'])
    best_acc = 0.0
    total_loss_curve = np.zeros((parameter['epochs']))  # 20230816
    total_celoss_curve = np.zeros((parameter['epochs']))
    total_dannloss_curve = np.zeros((parameter['epochs']))
    for epoch in range(parameter['epochs']):
        model.train()
        total_loss, total_num, target_bar = 0.0, 0, tqdm(target_loader)
        source_acc_total, target_acc_total = 0, 0

        total_celoss, total_dannloss = 0.0, 0.0  # 20230816

        # note: this assumes the source loader yields at least as many
        # batches per epoch as the target loader
        train_source_iter = enumerate(source_loaders)
        for data_target, label_target in target_bar:
            _, (data_source, labels_source) = next(train_source_iter)
            data_source, labels_source = data_source.to(device), labels_source.to(device)
            data_target, labels_target = data_target.to(device), label_target.to(device)

            pred, domain_loss, Sloss, dloss = model(torch.cat((data_source, data_target)))

            source_pred = pred[0:len(data_source), :]
            target_pred = pred[len(data_source):, :]

            # cross-entropy with one-hot labels via log-softmax
            log_prob = torch.nn.functional.log_softmax(source_pred, dim=1)
            celoss = -torch.sum(log_prob * labels_source) / len(labels_source)
            loss = celoss + domain_loss + Sloss + dloss
            source_scores = source_pred.detach().argmax(dim=1)
            source_acc = (source_scores == labels_source.argmax(dim=1)).float().sum().item()
            source_acc_total += source_acc
            target_scores = target_pred.detach().argmax(dim=1)
            target_acc = (target_scores == labels_target.argmax(dim=1)).float().sum().item()
            target_acc_total += target_acc

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_num += parameter['batch_size']
            total_loss += loss.item() * parameter['batch_size']
            epoch_train_loss = total_loss / total_num
            # 20230816
            total_celoss += celoss.item() * parameter['batch_size']
            epoch_train_celoss = total_celoss / total_num
            total_dannloss += domain_loss.item() * parameter['batch_size']
            epoch_train_dannloss = total_dannloss / total_num

            target_bar.set_description('sub:{} Train Epoch: [{}/{}] Loss: {:.4f} source_acc:{:.2f}% target_acc:{:.2f}%'
                                       .format(subject_id, epoch + 1, parameter['epochs'], epoch_train_loss,
                                               source_acc_total / total_num * 100,
                                               target_acc_total / total_num * 100))
        total_loss_curve[epoch] = epoch_train_loss
        total_celoss_curve[epoch] = epoch_train_celoss
        total_dannloss_curve[epoch] = epoch_train_dannloss

        if best_acc < (target_acc_total / total_num):
            best_acc = (target_acc_total / total_num)
        # scheduler.step(epoch_train_loss)
        # os.chdir('E:\\model_result')
        # torch.save(model.state_dict(), 'model' + str(subject_id) + '.pkl')
    return best_acc, total_loss_curve, total_celoss_curve, total_dannloss_curve
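# Editor's sketch (not part of the original repo): for one-hot labels the
# manual loss above matches F.cross_entropy on the integer targets.
def _demo_soft_label_ce():
    logits = torch.randn(4, 3)
    targets = torch.tensor([0, 2, 1, 1])
    one_hot = torch.nn.functional.one_hot(targets, num_classes=3).float()
    log_prob = torch.nn.functional.log_softmax(logits, dim=1)
    manual = -torch.sum(log_prob * one_hot) / len(one_hot)
    builtin = torch.nn.functional.cross_entropy(logits, targets)
    print(manual.item(), builtin.item())  # equal up to floating point error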
def test_GEN(subject_id, epoch, model, target_loader, parameter):
    model.eval()
    device = next(model.parameters()).device
    target_acc_total, total_num, target_bar = 0.0, 0, tqdm(target_loader)
    with torch.no_grad():
        for data_target, label_target in target_bar:
            data_target, label_target = data_target.to(device), label_target.to(device)
            pred, _, _, _ = model(data_target)
            target_scores = pred.detach().argmax(dim=1)
            target_acc = (target_scores == label_target.argmax(dim=1)).float().sum().item()
            target_acc_total += target_acc
            total_num += parameter['batch_size']
            target_bar.set_description('sub:{} Test Epoch: [{}/{}] target_acc:{:.2f}%'
                                       .format(subject_id, epoch + 1, parameter['epochs'],
                                               target_acc_total / total_num * 100))
    return target_acc_total / total_num

def AddContext(x, context, label=False, dtype='float32'):
    """Stack `context` consecutive samples into sliding temporal windows;
    for labels, simply crop the borders so features and labels stay aligned."""
    ret = []
    assert context % 2 == 1, "context value error."

    cut = context // 2
    if label:
        for p in range(len(x)):
            tData = x[p][cut:x[p].shape[0] - cut]
            ret.append(tData)
    else:
        for p in range(len(x)):
            tData = np.zeros([x[p].shape[0] - 2 * cut, context, x[p].shape[1], x[p].shape[2]], dtype=dtype)
            for i in range(cut, x[p].shape[0] - cut):
                tData[i - cut] = x[p][i - cut:i + cut + 1]
            ret.append(tData)
    return ret
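# Editor's sketch (not part of the original repo): with context=3, a trial of
# T samples of shape (62, 5) becomes T-2 windows of shape (3, 62, 5).
def _demo_add_context():
    trial = np.random.rand(10, 62, 5).astype('float32')
    windows = AddContext([trial], 3)
    print(windows[0].shape)  # (8, 3, 62, 5)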
def get_dataset(test_id, session):
    session = session + 1
    path = '/root/autodl-tmp/dataset/feature_for_net_session' + str(session) + '_LDS_de'
    os.chdir(path)
    feature_list_source_labeled = []
    label_list_source_labeled = []
    feature_list_target = []
    label_list_target = []
    min_max_scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
    # number of samples per film clip (trial) in SEED
    video_time = [235, 233, 206, 238, 185, 195, 237, 216, 265, 237, 235, 233, 235, 238, 206]
    index = 0
    for info in os.listdir(path):
        domain = os.path.abspath(path)
        info_ = os.path.join(domain, info)  # join the directory and the file name to get the full path
        if session == 1:
            feature = scio.loadmat(info_)['dataset_session1']['feature'][0, 0]
            label = scio.loadmat(info_)['dataset_session1']['label'][0, 0]
        elif session == 2:
            feature = scio.loadmat(info_)['dataset_session2']['feature'][0, 0]
            label = scio.loadmat(info_)['dataset_session2']['label'][0, 0]
        else:
            feature = scio.loadmat(info_)['dataset_session3']['feature'][0, 0]
            label = scio.loadmat(info_)['dataset_session3']['label'][0, 0]

        feature = min_max_scaler.fit_transform(feature).astype('float32')
        feature = feature.reshape(feature.shape[0], 62, 5, order='F')
        # split the continuous recording into trials so that the temporal
        # context windows never cross a trial boundary
        trial_list = []
        trial_label_list = []
        for video in range(len(video_time)):
            if video == 0:
                trial = feature[0:np.cumsum(video_time[0:video + 1])[-1], :]
                trial_label = label[0:np.cumsum(video_time[0:video + 1])[-1], :]
            else:
                trial = feature[np.cumsum(video_time[0:video])[-1]:np.cumsum(video_time[0:video + 1])[-1], :]
                trial_label = label[np.cumsum(video_time[0:video])[-1]:np.cumsum(video_time[0:video + 1])[-1], :]
            trial_list.append(trial)
            trial_label_list.append(trial_label)
        feature = AddContext(trial_list, 3)
        label = AddContext(trial_label_list, 3, label=True)
        feature = np.vstack(feature)
        label = np.vstack(label)

        one_hot_label_mat = np.zeros((len(label), 3))
        for i in range(len(label)):
            if label[i] == 0:
                one_hot_label = [1, 0, 0]
                one_hot_label = np.hstack(one_hot_label).reshape(1, 3)
                one_hot_label_mat[i, :] = one_hot_label
            if label[i] == 1:
                one_hot_label = [0, 1, 0]
                one_hot_label = np.hstack(one_hot_label).reshape(1, 3)
                one_hot_label_mat[i, :] = one_hot_label
            if label[i] == 2:
                one_hot_label = [0, 0, 1]
                one_hot_label = np.hstack(one_hot_label).reshape(1, 3)
                one_hot_label_mat[i, :] = one_hot_label

        if index != test_id:
            # source labeled data
            feature_list_source_labeled.append(feature)
            label_list_source_labeled.append(one_hot_label_mat)
        else:
            # target data
            feature_list_target.append(feature)
            label_list_target.append(one_hot_label_mat)
        index += 1

    source_feature_labeled, source_label_labeled = np.vstack(feature_list_source_labeled), np.vstack(label_list_source_labeled)

    target_feature = feature_list_target[0]
    target_label = label_list_target[0]

    target_set = {'feature': target_feature, 'label': target_label}
    source_set_labeled = {'feature': source_feature_labeled, 'label': source_label_labeled}

    return target_set, source_set_labeled
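# Editor's sketch (not part of the original repo): the one-hot loop in
# get_dataset is equivalent to indexing an identity matrix with the labels.
def _demo_one_hot():
    label = np.array([[0], [2], [1]])
    one_hot = np.eye(3)[label.ravel().astype(int)]
    print(one_hot)  # rows: [1,0,0], [0,0,1], [0,1,0]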
def cross_subject(target_set, source_set_labeled, session_id, subject_id, parameter, net_params):
    setup_seed(20)
    torch_dataset_test = Data.TensorDataset(torch.from_numpy(target_set['feature']), torch.from_numpy(target_set['label']))
    torch_dataset_source_labeled = Data.TensorDataset(torch.from_numpy(source_set_labeled['feature']), torch.from_numpy(source_set_labeled['label']))
    source_loaders = torch.utils.data.DataLoader(dataset=torch_dataset_source_labeled,
                                                 batch_size=parameter['batch_size'],
                                                 shuffle=True,
                                                 drop_last=True)
    target_loader = torch.utils.data.DataLoader(dataset=torch_dataset_test,
                                                batch_size=parameter['batch_size'],
                                                shuffle=True,
                                                drop_last=True)

    acc = train_GEN_contrast(subject_id, parameter, net_params, source_loaders=source_loaders, target_loader=target_loader)
    return acc


def main(parameter, net_params):
    # data preparation: create the output directories actually written to below
    os.makedirs('/root/figures', exist_ok=True)
    os.makedirs('/root/csvfile', exist_ok=True)

    setup_seed(20)
    print('Model name: GraphEmotionNet. Dataset name: ', parameter['dataset_name'])
    print('BS: {}, epoch: {}'.format(parameter['batch_size'], parameter['epochs']))

    # store the results
    csub = []
    # for session_id_main in range(3):
    session_id = 0
    for subject_id in range(15):
        loss_curve = []  # 20230816
        celoss_curve = []
        dannloss_curve = []

        target_set, source_set_labeled = get_dataset(subject_id, session_id)
        result = cross_subject(target_set, source_set_labeled, session_id, subject_id, parameter, net_params)
        csub.append(result[0])
        loss_curve.append(result[1])  # 20230816
        celoss_curve.append(result[2])
        dannloss_curve.append(result[3])
        loss_curve = [num for sublist in loss_curve for num in sublist]
        celoss_curve = [num for sublist in celoss_curve for num in sublist]
        dannloss_curve = [num for sublist in dannloss_curve for num in sublist]

        loss_pd = pd.DataFrame(loss_curve)
        loss_pd.to_csv(os.path.join('/root', 'csvfile', f'id{subject_id + 1}_loss.csv'))

        celoss_pd = pd.DataFrame(celoss_curve)
        celoss_pd.to_csv(os.path.join('/root', 'csvfile', f'id{subject_id + 1}_celoss.csv'))

        dannloss_pd = pd.DataFrame(dannloss_curve)
        dannloss_pd.to_csv(os.path.join('/root', 'csvfile', f'id{subject_id + 1}_dannloss.csv'))

        plt.rc('font', family='Times New Roman')
        plt.figure()
        plt.plot(loss_curve)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.savefig(os.path.join('/root', 'figures', f'id{subject_id + 1}_loss.png'))

        plt.figure()
        plt.plot(celoss_curve)
        plt.xlabel('Epoch')
        plt.ylabel('CeLoss')
        plt.savefig(os.path.join('/root', 'figures', f'id{subject_id + 1}_CeLoss.png'))

        plt.figure()
        plt.plot(dannloss_curve)
        plt.xlabel('Epoch')
        plt.ylabel('DannLoss')
        plt.savefig(os.path.join('/root', 'figures', f'id{subject_id + 1}_DannLoss.png'))

    print("Cross-subject: ", csub)
    return csub, loss_curve, celoss_curve, dannloss_curve


parameter = {'dataset_name': 'seed3', 'epochs': 1000, 'batch_size': 96, 'init_lr': 1e-3, 'weight_decay': 1e-2}
net_params = {'GLalpha': 1e-2, 'node_feature_hidden1': 5, 'node_feature_hidden2': 5, 'adv_alpha': 1, 'aug_type': 'nn',
              'in_dim': 5, 'hidden_dim': 5, 'out_dim': 5, 'in_feat_dropout': 0.0, 'dropout': 0, 'n_layers': 2,
              'readout': 'mean', 'graph_norm': True, 'batch_norm': True, 'residual': True, 'category_number': 3,
              'DEVICE': 'cuda:0', 'K': 2, 'num_of_timesteps': 3, 'num_of_vertices': 62, 'num_of_features': 5,
              }
csub, loss_curve, celoss_curve, dannloss_curve = main(parameter, net_params)
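# Editor's sketch (not part of the original repo): a smoke test of the data
# pipeline above with random arrays shaped like the SEED features, i.e.
# (samples, num_of_timesteps, num_of_vertices, num_of_features).
def _demo_loader():
    feats = np.random.rand(200, 3, 62, 5).astype('float32')
    labels = np.eye(3)[np.random.randint(0, 3, 200)].astype('float32')
    ds = Data.TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
    loader = torch.utils.data.DataLoader(ds, batch_size=96, shuffle=True, drop_last=True)
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # torch.Size([96, 3, 62, 5]) torch.Size([96, 3])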
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import adv_layer
import math

class TemporalAttention(nn.Module):
    '''
    compute temporal attention scores
    --------
    Input: (batch_size, num_of_timesteps, num_of_vertices, num_of_features)
    Output: (batch_size, num_of_timesteps, num_of_timesteps)
    '''
    def __init__(self, device, num_of_timesteps, num_of_vertices, num_of_features):
        super(TemporalAttention, self).__init__()
        self.U1 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_vertices).to(device)))
        self.U2 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_features, num_of_vertices).to(device)))
        self.U3 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_features).to(device)))
        self.be = nn.init.normal_(nn.Parameter(torch.FloatTensor(1, num_of_timesteps, num_of_timesteps).to(device)))
        self.Ve = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_timesteps, num_of_timesteps).to(device)))

    def forward(self, x):
        '''
        :param x: (batch_size, T, V, F)
        :return: (B, T, T)
        '''
        lhs = torch.matmul(torch.matmul(x.permute(0, 1, 3, 2), self.U1), self.U2)  # (B,T,F,V)(V)->(B,T,F), (F,V)->(B,T,V)
        rhs = torch.matmul(self.U3, x.permute(0, 2, 3, 1))  # (F)(B,V,F,T)->(B,V,T)
        product = torch.matmul(lhs, rhs)  # (B,T,V)(B,V,T)->(B,T,T)
        E = torch.matmul(self.Ve, torch.sigmoid(product + self.be))  # (B,T,T)
        E_normalized = F.softmax(E, dim=1)
        return E_normalized

def diff_loss(diff, S, Falpha):
    '''
    compute the 1st term of L_{graph_learning}
    '''
    if len(S.shape) == 4:
        # batch input
        return Falpha * torch.mean(torch.sum(torch.sum(diff**2, dim=3) * S, dim=(1, 2)))
    else:
        return Falpha * torch.sum(torch.matmul(S, torch.sum(diff**2, dim=2)))

def F_norm_loss(S, Falpha):
    '''
    compute the 2nd term of L_{graph_learning}
    '''
    if len(S.shape) == 3:
        # batch input
        return Falpha * torch.sum(torch.mean(S**2, dim=0))
    else:
        return Falpha * torch.sum(S**2)

class Graph_Learn(nn.Module):
    '''
    Graph structure learning (based on the middle time slice)
    --------
    Input: (batch_size, num_of_timesteps, num_of_vertices, num_of_features)
    Output: (batch_size, num_of_vertices, num_of_vertices)
    '''
    def __init__(self, alpha, num_of_features, device):
        super(Graph_Learn, self).__init__()
        self.alpha = alpha
        self.a = nn.init.uniform_(nn.Parameter(torch.FloatTensor(num_of_features, 1).to(device)))
        self.S = torch.zeros(1, 1, 1, 1)  # similar to a placeholder
        self.diff = torch.zeros(1, 1, 1, 1, 1)  # similar to a placeholder

    def forward(self, x):
        N, T, V, f = x.shape
        # (N,V,F): use the middle time slice
        x = x[:, T // 2, :, :]
        # (N,V,V,F): pairwise differences between the V node feature vectors
        diff = (x.expand(V, N, V, f).permute(2, 1, 0, 3) - x.expand(V, N, V, f)).permute(1, 0, 2, 3)
        # (N,V,V)
        tmpS = torch.exp(F.relu(torch.reshape(torch.matmul(torch.abs(diff), self.a), [N, V, V])))
        # normalization along dim 1
        S = tmpS / torch.sum(tmpS, dim=1, keepdim=True)
        self.diff = diff
        self.S = S
        Sloss = F_norm_loss(self.S, self.alpha)
        dloss = diff_loss(self.diff, self.S, self.alpha)
        return S, Sloss, dloss
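# Editor's sketch (not part of the original repo): the learned adjacency is
# normalized along dim 1, so S sums to one over that dimension.
def _demo_graph_learn():
    gl = Graph_Learn(alpha=1e-2, num_of_features=5, device='cpu')
    x = torch.randn(2, 3, 62, 5)           # (N, T, V, F)
    S, Sloss, dloss = gl(x)
    print(S.shape, S.sum(dim=1)[0, :3])    # torch.Size([2, 62, 62]), ~[1., 1., 1.]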
class SpatialAttention(nn.Module):
    '''
    compute spatial attention scores
    --------
    Input: (batch_size, num_of_timesteps, num_of_vertices, num_of_features)
    Output: (batch_size, num_of_vertices, num_of_vertices)
    '''
    def __init__(self, device, num_of_timesteps, num_of_vertices, num_of_features):
        super(SpatialAttention, self).__init__()
        self.W1 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_timesteps).to(device)))
        self.W2 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_features, num_of_timesteps).to(device)))
        self.W3 = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_features).to(device)))
        self.bs = nn.init.normal_(nn.Parameter(torch.FloatTensor(1, num_of_vertices, num_of_vertices).to(device)))
        self.Vs = nn.init.normal_(nn.Parameter(torch.FloatTensor(num_of_vertices, num_of_vertices).to(device)))

    def forward(self, x):
        '''
        :param x: (batch_size, T, V, F)
        :return: (B, V, V)
        '''
        lhs = torch.matmul(torch.matmul(x.permute(0, 2, 3, 1), self.W1), self.W2)  # (B,V,F,T)(T)->(B,V,F), (F,T)->(B,V,T)
        rhs = torch.matmul(self.W3, x.permute(0, 1, 3, 2))  # (F)(B,T,F,V)->(B,T,V)
        product = torch.matmul(lhs, rhs)  # (B,V,T)(B,T,V)->(B,V,V)
        S = torch.matmul(self.Vs, torch.sigmoid(product + self.bs))  # (V,V)(B,V,V)->(B,V,V)
        S_normalized = F.softmax(S, dim=1)
        return S_normalized
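# Editor's sketch (not part of the original repo): shape check of the
# spatial attention with the SEED configuration (62 channels, 5 features).
def _demo_spatial_attention():
    sat = SpatialAttention('cpu', num_of_timesteps=3, num_of_vertices=62, num_of_features=5)
    x = torch.randn(2, 3, 62, 5)
    att = sat(x)
    print(att.shape)  # torch.Size([2, 62, 62])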
class cheb_conv_with_SAt_GL(nn.Module):
    '''
    K-order Chebyshev graph convolution after graph learning
    --------
    Input: [x (batch_size, num_of_timesteps, num_of_vertices, num_of_features),
            SAtt (batch_size, num_of_vertices, num_of_vertices),
            S (batch_size, num_of_vertices, num_of_vertices)]
    Output: (batch_size, num_of_timesteps, num_of_vertices, num_of_filters)
    '''
    def __init__(self, num_of_filters, k, num_of_features, device):
        super(cheb_conv_with_SAt_GL, self).__init__()
        self.Theta = nn.ParameterList([nn.init.uniform_(nn.Parameter(torch.FloatTensor(num_of_features, num_of_filters).to(device))) for _ in range(k)])
        self.out_channels = num_of_filters
        self.K = k
        self.device = device

    def forward(self, x):
        # Input: [x, SAtt, S]
        assert isinstance(x, list)
        assert len(x) == 3, 'cheb_conv_with_SAt_GL: number of input error'
        x, spatial_attention, W = x
        N, T, V, f = x.shape
        # calculate the Chebyshev polynomials
        D = torch.diag_embed(torch.sum(W, dim=1))
        L = D - W
        '''
        Here we approximate lambda_max as 2 to simplify the calculation.
        For a more general calculation, use the largest eigenvalue of L per
        sample instead: L_t = (2 * L) / lambda_max - I.
        '''
        lambda_max = 2.0
        L_t = (2 * L) / lambda_max - torch.eye(int(V)).to(self.device)
        cheb_polynomials = [torch.eye(int(V)).to(self.device), L_t]
        for i in range(2, self.K):
            # Chebyshev recurrence T_k = 2 L~ T_{k-1} - T_{k-2} (a matrix product)
            cheb_polynomials.append(2 * torch.matmul(L_t, cheb_polynomials[i - 1]) - cheb_polynomials[i - 2])
        # graph convolution
        outputs = []
        for time_step in range(T):
            graph_signal = x[:, time_step, :, :]  # (b, V, F_in)
            output = torch.zeros(N, V, self.out_channels).to(self.device)  # (b, V, F_out)
            for k in range(self.K):
                T_k = cheb_polynomials[k]  # (V,V)
                T_k_with_at = T_k.mul(spatial_attention)  # (V,V)*(b,V,V) = (b,V,V); the attention is softmax-normalized along dim 1
                theta_k = self.Theta[k]  # (in_channel, out_channel)
                rhs = T_k_with_at.permute(0, 2, 1).matmul(graph_signal)  # (b,V,V)(b,V,F_in) = (b,V,F_in); transpose before left-multiplying so the normalized dimension lines up with the neighbour aggregation
                output = output + rhs.matmul(theta_k)  # (b,V,F_in)(F_in,F_out) = (b,V,F_out)
            outputs.append(output.unsqueeze(1))  # (b, 1, V, F_out)
        return F.relu(torch.cat(outputs, dim=1))  # (b, T, V, F_out)

class cheb_conv_withSAt(nn.Module):
    '''
    K-order Chebyshev graph convolution
    '''
    def __init__(self, K, cheb_polynomials, in_channels, out_channels):
        '''
        :param K: int
        :param in_channels: int, num of channels in the input sequence
        :param out_channels: int, num of channels in the output sequence
        '''
        super(cheb_conv_withSAt, self).__init__()
        self.K = K
        self.cheb_polynomials = cheb_polynomials
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.DEVICE = torch.device('cuda:0')  # cheb_polynomials[0].device
        self.Theta = nn.ParameterList([nn.Parameter(torch.FloatTensor(in_channels, out_channels).to(self.DEVICE)) for _ in range(K)])

    def forward(self, x, spatial_attention):
        '''
        Chebyshev graph convolution operation
        :param x: (batch_size, T, V, F_in)
        :return: (batch_size, T, V, F_out)
        '''
        batch_size, num_of_timesteps, num_of_vertices, num_of_features = x.shape
        outputs = []
        for time_step in range(num_of_timesteps):
            graph_signal = x[:, time_step, :, :]  # (b, V, F_in)
            output = torch.zeros(batch_size, num_of_vertices, self.out_channels).to(self.DEVICE)  # (b, V, F_out)
            for k in range(self.K):
                T_k = self.cheb_polynomials[k]  # (V,V)
                T_k_with_at = T_k.mul(spatial_attention)  # (V,V)*(b,V,V) = (b,V,V); the attention is softmax-normalized along dim 1
                theta_k = self.Theta[k]  # (in_channel, out_channel)
                rhs = T_k_with_at.permute(0, 2, 1).matmul(graph_signal)  # (b,V,V)(b,V,F_in) = (b,V,F_in)
                output = output + rhs.matmul(theta_k)  # (b,V,F_in)(F_in,F_out) = (b,V,F_out)
            outputs.append(output.unsqueeze(1))  # (b, 1, V, F_out)
        return F.relu(torch.cat(outputs, dim=1))  # (b, T, V, F_out)
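# Editor's sketch (not part of the original repo): shape check of the graph
# convolution with a random learned graph and random attention as stand-ins.
def _demo_cheb_conv():
    conv = cheb_conv_with_SAt_GL(num_of_filters=5, k=2, num_of_features=5, device='cpu')
    x = torch.randn(2, 3, 62, 5)                        # (N, T, V, F)
    att = torch.softmax(torch.randn(2, 62, 62), dim=1)  # stand-in spatial attention
    S = torch.softmax(torch.randn(2, 62, 62), dim=1)    # stand-in learned graph
    out = conv([x, att, S])
    print(out.shape)  # torch.Size([2, 3, 62, 5])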
class GEN_block(nn.Module):

    def __init__(self, net_params):
        super(GEN_block, self).__init__()
        self.num_of_timesteps = net_params['num_of_timesteps']
        self.num_of_vertices = net_params['num_of_vertices']
        self.num_of_features = net_params['num_of_features']
        device = net_params['DEVICE']
        node_feature_hidden1 = net_params['node_feature_hidden1']
        node_feature_hidden2 = net_params['node_feature_hidden2']
        self.TAt = TemporalAttention(device, self.num_of_timesteps, self.num_of_vertices, self.num_of_features)
        self.SAt = SpatialAttention(device, self.num_of_timesteps, self.num_of_vertices, self.num_of_features)
        self.Graph_Learn = Graph_Learn(net_params['GLalpha'], self.num_of_features, device)
        self.cheb_conv_SAt_GL = cheb_conv_with_SAt_GL(node_feature_hidden1, net_params['K'], self.num_of_features, device)
        self.time_conv = nn.Conv2d(node_feature_hidden1, node_feature_hidden2, kernel_size=(1, 3), stride=(1, 3))
        self.residual_conv = nn.Conv2d(self.num_of_features, node_feature_hidden2, kernel_size=(1, 1), stride=(1, 3))
        self.ln = nn.LayerNorm(node_feature_hidden2)  # LayerNorm expects the channel dimension to be last

    def forward(self, x):
        '''
        x: input (bs, T, V, F)
        return: (bs, T, V, F2)
        '''
        # temporal attention
        temporal_At = self.TAt(x)  # (b, T, T)
        x_TAt = torch.matmul((x.permute(0, 2, 3, 1)).reshape(x.shape[0], -1, self.num_of_timesteps), temporal_At).reshape(x.shape[0], self.num_of_vertices, self.num_of_features, self.num_of_timesteps)
        x_TAt = x_TAt.permute(0, 3, 1, 2) / math.sqrt(62 * 5)
        # spatial attention and graph learning
        spatial_At = self.SAt(x_TAt)
        S, Sloss, dloss = self.Graph_Learn(x)
        spatial_gcn = self.cheb_conv_SAt_GL([x, spatial_At, S])
        # convolution along the time axis: (b,T,V,F)->(b,F,V,T), (1,3) kernel -> (b,F,V,1)
        time_conv_output = self.time_conv(spatial_gcn.permute(0, 3, 2, 1)) / math.sqrt(3 * 5)
        # residual shortcut: (b,T,V,F)->(b,F,V,T), (1,1) kernel with stride (1,3) -> (b,F,V,1)
        x_residual = self.residual_conv(x.permute(0, 3, 2, 1))
        x_residual = self.ln(F.relu(x_residual + time_conv_output).permute(0, 3, 2, 1))  # (b,F,V,T)->(b,T,V,F)
        x_residual = x_residual.squeeze(1)
        return x_residual, Sloss, dloss, S

class feature_extractor(nn.Module):
    def __init__(self, input, hidden_1, hidden_2):
        super(feature_extractor, self).__init__()
        self.fc1 = nn.Linear(input, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.25)

    def forward(self, x):
        x = self.fc1(x)
        x1 = F.relu(x)
        # x1 = F.leaky_relu(x)
        x2 = self.fc2(x1)
        x2 = F.relu(x2)
        return x2


class GEN_contrastNet(nn.Module):
    def __init__(self, net_params):
        super(GEN_contrastNet, self).__init__()
        self.device = net_params['DEVICE']
        self.GEN = GEN_block(net_params)
        self.adv_alpha = net_params['adv_alpha']
        self.domain_classifier = adv_layer.DomainAdversarialLoss(hidden_1=64)
        self.aug_type = net_params['aug_type']
        self.fea_extractor_f = feature_extractor(310, 64, 64)  # 310 = 62 vertices * 5 features
        self.classifier_noproto = nn.Linear(64, 3)
        self.g_list = []

    def forward(self, x):
        # spatial-temporal graph feature extraction
        feature, Sloss, dloss, S = self.GEN(x)
        feature1 = torch.flatten(feature, start_dim=1, end_dim=-1)
        feature1 = self.fea_extractor_f(feature1)
        pred = self.classifier_noproto(feature1)
        domain_output = self.domain_classifier(feature1)
        return pred, domain_output, Sloss, dloss
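# Editor's sketch (not part of the original repo): an end-to-end smoke test
# on CPU, reusing the hyper-parameters from main.py with DEVICE='cpu'. The
# batch is assumed to hold the source samples first, then the target ones.
def _demo_gen_contrast_net():
    net_params = {'GLalpha': 1e-2, 'node_feature_hidden1': 5, 'node_feature_hidden2': 5,
                  'adv_alpha': 1, 'aug_type': 'nn', 'K': 2, 'num_of_timesteps': 3,
                  'num_of_vertices': 62, 'num_of_features': 5, 'DEVICE': 'cpu'}
    model = GEN_contrastNet(net_params)
    x = torch.randn(8, 3, 62, 5)   # 4 source + 4 target samples
    pred, domain_loss, Sloss, dloss = model(x)
    print(pred.shape, float(domain_loss))  # torch.Size([8, 3]) and a scalar BCE loss
--------------------------------------------------------------------------------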