└── LTSGAT.py

/LTSGAT.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from grl import ReverseLayerF
import numpy as np
from einops import rearrange
from Electrodes import Electrodes
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn import GATConv, SGConv, TAGConv
import torch
import torch.nn.functional as F
from torch.autograd import Function
from typing import Any, Optional, Tuple


class Attention_Layer1(nn.Module):
    # Implements the attention layer applied to the raw node features.
    def __init__(self):
        super(Attention_Layer1, self).__init__()
        # Q, K, V projection matrices implemented as nn.Linear layers over the 32 electrode channels
        self.Q_linear = nn.Linear(32, 32, bias=False)
        self.K_linear = nn.Linear(32, 32, bias=False)
        self.V_linear = nn.Linear(32, 32, bias=False)

    def forward(self, data):
        # Compute the Q, K, V matrices
        att_input = data.x
        att_input = rearrange(att_input, '(b i) sl -> b sl i', i=32)
        Q = self.Q_linear(att_input)
        K = self.K_linear(att_input).permute(0, 2, 1)  # transpose K for the dot product
        V = self.V_linear(att_input)

        # Attention scores followed by softmax over the last dimension
        alpha = torch.matmul(Q, K)
        alpha = F.softmax(alpha, dim=2)
        out = torch.matmul(alpha, V)
        return out


"""
class GradientReverseFunction(Function):

    # Custom gradient: identity in the forward pass, reversed (and scaled) gradient in the backward pass

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None
"""


class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        # input_size is the per-channel feature dimension
        self.input_size = input_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x):
        x = rearrange(x, 'b sl i -> b i sl', i=32)
        x, _ = self.lstm(x)  # (b, 32, 2 * hidden_size)
        return x


def Region_apart(x):
    # Split the 32 electrode channels into nine brain regions by concatenating
    # each region's channel features along the last dimension.
    region1 = torch.cat((x[:, 0, :], x[:, 1, :], x[:, 16, :], x[:, 17, :]), dim=1)
    region2 = torch.cat((x[:, 2, :], x[:, 3, :], x[:, 18, :], x[:, 19, :], x[:, 20, :]), dim=1)
    region3 = torch.cat((x[:, 7, :], x[:, 25, :]), dim=1)
    region4 = torch.cat((x[:, 6, :], x[:, 23, :], x[:, 24, :]), dim=1)
    region5 = torch.cat((x[:, 4, :], x[:, 5, :], x[:, 21, :], x[:, 22, :]), dim=1)
    region6 = torch.cat((x[:, 8, :], x[:, 9, :], x[:, 26, :], x[:, 27, :]), dim=1)
    region7 = torch.cat((x[:, 10, :], x[:, 11, :], x[:, 15, :], x[:, 28, :], x[:, 29, :]), dim=1)
    region8 = torch.cat((x[:, 12, :], x[:, 30, :]), dim=1)
    region9 = torch.cat((x[:, 13, :], x[:, 14, :], x[:, 31, :]), dim=1)
    return region1, region2, region3, region4, region5, region6, region7, region8, region9
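
# Minimal shape-check sketch (illustrative addition, not part of the original
# model): Region_apart should split a (batch, 32, feat) tensor into nine region
# tensors that are feat times 4/5/2/3/4/4/5/2/3 columns wide, matching the
# region-specific Linear layers defined in mynet below.
def _region_apart_shape_check(batch_size=2, feat=64):
    regions = Region_apart(torch.randn(batch_size, 32, feat))
    return [r.shape[1] // feat for r in regions]  # expected: [4, 5, 2, 3, 4, 4, 5, 2, 3]
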
class Attention_Layer(nn.Module):
    # Implements the attention layer applied over the nine region embeddings.
    def __init__(self):
        super(Attention_Layer, self).__init__()
        # Q, K, V projection matrices implemented as nn.Linear layers
        self.Q_linear = nn.Linear(256 // 2, 256 // 2, bias=False)
        self.K_linear = nn.Linear(256 // 2, 256 // 2, bias=False)
        self.V_linear = nn.Linear(256 // 2, 256 // 2, bias=False)

    def forward(self, att_input):
        # Compute the Q, K, V matrices
        Q = self.Q_linear(att_input)
        K = self.K_linear(att_input).permute(0, 2, 1)  # transpose K for the dot product
        V = self.V_linear(att_input)

        # Attention scores followed by softmax over the last dimension
        alpha = torch.matmul(Q, K)
        alpha = F.softmax(alpha, dim=2)
        out = torch.matmul(alpha, V)
        return out


class GAT(torch.nn.Module):
    def __init__(self, num_features):
        super(GAT, self).__init__()
        self.gat1 = GATConv(num_features, 16, heads=4, dropout=0.6)
        self.gat2 = GATConv(64, 16, heads=4, dropout=0.6)
        self.gat3 = GATConv(64, 16, heads=4, dropout=0.6)
        self.gat4 = GATConv(64, 64, heads=1, dropout=0.6)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.leaky_relu(self.gat1(x, edge_index))
        x = F.leaky_relu(self.gat2(x, edge_index))
        x = F.leaky_relu(self.gat3(x, edge_index))
        x = F.leaky_relu(self.gat4(x, edge_index))
        # Mean-pool node embeddings into one 64-dim vector per graph
        x = gap(x, batch)
        # x = torch.cat((gmp(x, batch), gap(x, batch)), dim=1)
        return x
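
# Minimal smoke-test sketch (illustrative addition, not part of the original
# model): run the GAT block on a toy two-node graph and confirm it returns one
# pooled 64-dim vector per graph. The node feature width must equal the value
# passed to GAT(); mynet below constructs it with 128 // 2 = 64 features.
def _gat_smoke_test(num_features=64):
    x = torch.randn(2, num_features)
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
    batch = torch.zeros(2, dtype=torch.long)
    toy = Data(x=x, edge_index=edge_index, batch=batch)
    return GAT(num_features)(toy).shape  # expected: torch.Size([1, 64])
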
class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.attention1 = Attention_Layer1()
        self.LSTM = BiLSTM(10, 64 // 2)
        self.dropout1 = nn.Dropout(0.5)

        # Project each region's concatenated features to a common 128-dim space
        self.linear1 = nn.Linear(512 // 2, 256 // 2)
        self.linear2 = nn.Linear(640 // 2, 256 // 2)
        self.linear3 = nn.Linear(256 // 2, 256 // 2)
        self.linear4 = nn.Linear(384 // 2, 256 // 2)
        self.linear5 = nn.Linear(512 // 2, 256 // 2)
        self.linear6 = nn.Linear(512 // 2, 256 // 2)
        self.linear7 = nn.Linear(640 // 2, 256 // 2)
        self.linear8 = nn.Linear(256 // 2, 256 // 2)
        self.linear9 = nn.Linear(384 // 2, 256 // 2)

        self.attention = Attention_Layer()

        # Project each region back to its original width after attention
        self.linear_1 = nn.Linear(256 // 2, 512 // 2)
        self.linear_2 = nn.Linear(256 // 2, 640 // 2)
        self.linear_3 = nn.Linear(256 // 2, 256 // 2)
        self.linear_4 = nn.Linear(256 // 2, 384 // 2)
        self.linear_5 = nn.Linear(256 // 2, 512 // 2)
        self.linear_6 = nn.Linear(256 // 2, 512 // 2)
        self.linear_7 = nn.Linear(256 // 2, 640 // 2)
        self.linear_8 = nn.Linear(256 // 2, 256 // 2)
        self.linear_9 = nn.Linear(256 // 2, 384 // 2)

        self.dropout2 = nn.Dropout(0.5)
        self.batchnorm2 = nn.BatchNorm1d(32)

        self.gconv = GAT(128 // 2)

        # Emotion (class) classifier head
        # (inplace=True; the original nn.LeakyReLU(True) set negative_slope=1, i.e. an identity)
        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(64, 64))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(64))
        self.class_classifier.add_module('c_leaky_relu1', nn.LeakyReLU(inplace=True))
        self.class_classifier.add_module('c_fc2', nn.Linear(64, 64))
        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(64))
        self.class_classifier.add_module('c_leaky_relu2', nn.LeakyReLU(inplace=True))
        self.class_classifier.add_module('c_fc3', nn.Linear(64, 2))

        # Domain classifier head, trained through the gradient-reversal layer
        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(64, 64))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(64))
        self.domain_classifier.add_module('d_leaky_relu1', nn.LeakyReLU(inplace=True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(64, 2))

    def forward(self, data, alpha):
        x = self.attention1(data)
        x = self.LSTM(x)
        x = self.dropout1(x)

        region1, region2, region3, region4, region5, region6, region7, region8, region9 = Region_apart(x)
        # Project the nine regions to a common dimension
        out1 = F.leaky_relu(self.linear1(region1))
        out2 = F.leaky_relu(self.linear2(region2))
        out3 = F.leaky_relu(self.linear3(region3))
        out4 = F.leaky_relu(self.linear4(region4))
        out5 = F.leaky_relu(self.linear5(region5))
        out6 = F.leaky_relu(self.linear6(region6))
        out7 = F.leaky_relu(self.linear7(region7))
        out8 = F.leaky_relu(self.linear8(region8))
        out9 = F.leaky_relu(self.linear9(region9))
        # Stack the nine region embeddings per sample: (b, 9, 128).
        # (The original cat along dim 0 followed by rearrange '(b i) sl -> b i sl'
        # interleaved samples across regions; stacking on dim 1 keeps each
        # sample's regions together.)
        out = torch.stack((out1, out2, out3, out4, out5, out6, out7, out8, out9), dim=1)

        # Attention over the nine brain regions
        out = self.attention(out)

        # Project each region back to its original width
        output1 = F.leaky_relu(self.linear_1(out[:, 0, :]))
        output2 = F.leaky_relu(self.linear_2(out[:, 1, :]))
        output3 = F.leaky_relu(self.linear_3(out[:, 2, :]))
        output4 = F.leaky_relu(self.linear_4(out[:, 3, :]))
        output5 = F.leaky_relu(self.linear_5(out[:, 4, :]))
        output6 = F.leaky_relu(self.linear_6(out[:, 5, :]))
        output7 = F.leaky_relu(self.linear_7(out[:, 6, :]))
        output8 = F.leaky_relu(self.linear_8(out[:, 7, :]))
        output9 = F.leaky_relu(self.linear_9(out[:, 8, :]))
        output = torch.cat((output1, output2, output3, output4, output5, output6, output7, output8, output9), dim=1)
        # Merge the nine brain regions back into 32 channels of 64 features
        x = rearrange(output, 'b (i sl) -> b i sl', i=32)

        x = self.dropout2(x)
        x = self.batchnorm2(x)

        x = rearrange(x, 'b i sl -> (b i) sl')
        data.x = x
        feature = self.gconv(data)
        # Gradient reversal for adversarial domain adaptation
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        logits1 = self.class_classifier(feature)
        logits2 = self.domain_classifier(reverse_feature)
        class_output = F.softmax(logits1, dim=1)
        domain_output = F.softmax(logits2, dim=1)

        return logits1, class_output, logits2, domain_output

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.LSTM):
                # nn.LSTM has no single .weight/.bias attribute; initialise each parameter tensor
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        torch.nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='leaky_relu')
                    elif 'bias' in name:
                        torch.nn.init.constant_(param, val=0.0)
            elif isinstance(m, GATConv):
                # GATConv manages its own attention and linear weights
                m.reset_parameters()
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, val=0.0)
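
# Minimal end-to-end sketch (illustrative addition; the shapes below are
# assumptions inferred from the layer definitions above, not from the original
# file): each graph is one EEG sample with 32 electrode nodes and 10 features
# per node. Two graphs are batched, a fully connected toy adjacency stands in
# for the real electrode graph, and eval() is used so the BatchNorm layers do
# not need batch statistics.
def _mynet_smoke_test():
    num_graphs, num_nodes, num_feats = 2, 32, 10
    x = torch.randn(num_graphs * num_nodes, num_feats)
    # Fully connect each graph's nodes (no self-loops), then offset indices per graph.
    src, dst = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing='ij')
    mask = src != dst
    base = torch.stack((src[mask], dst[mask]))
    edge_index = torch.cat([base + g * num_nodes for g in range(num_graphs)], dim=1)
    batch = torch.arange(num_graphs).repeat_interleave(num_nodes)
    toy = Data(x=x, edge_index=edge_index, batch=batch)
    model = mynet().eval()
    logits1, class_out, logits2, domain_out = model(toy, alpha=0.1)
    return logits1.shape, logits2.shape  # expected: torch.Size([2, 2]) each
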
--------------------------------------------------------------------------------