├── README.md └── Code ├── loss_function.py ├── test.py ├── valid.py ├── data_loader.py ├── train.py ├── utils.py └── net.py /README.md: -------------------------------------------------------------------------------- 1 | # Multi-modal Triplet Attention Network for Brain Diease Diagnosis 2 | A method for epilepsy diagnosis based on multi-modal brain images. 3 | 4 | ## File description 5 | - train.py -- Main file for setting and training model. 6 | - valid.py -- Evaluation for validation set. 7 | - test.py -- Evaluate the model. 8 | - net.py -- Networks of our method. 9 | - data_loader.py -- Load data from dataset 10 | - loss_function.py -- Triplet loss function used in our method. 11 | - utils.py -- Some functions utilized in our model. 12 | 13 | ## Prerequisite 14 | - Python environment with python 3.7 15 | - Install pytorch >= 1.8.0 16 | - Others 17 | - numpy, scipy, sklearn, einops, etc. 18 | -------------------------------------------------------------------------------- /Code/loss_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MyTriplet_loss(nn.Module): 6 | def __init__(self, margin=1.0, loss_weight=1.0): 7 | super(MyTriplet_loss, self).__init__() 8 | self.margin = margin 9 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 10 | self.loss_weight = loss_weight 11 | 12 | def forward(self, inputs, targets): 13 | # distances 14 | n = inputs.size(0) 15 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 16 | dist = dist + dist.t() 17 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) 18 | dist = dist.clamp(min=1e-16).sqrt() 19 | ap_dist = dist.unsqueeze(2) 20 | an_dist = dist.unsqueeze(1) 21 | triplet_loss = ap_dist - an_dist + self.margin 22 | # triplets mask 23 | mask = get_mask(targets) 24 | # loss 25 | triplet_loss = torch.multiply(mask, triplet_loss) 26 | triplet_loss = torch.maximum(triplet_loss, torch.tensor(0.0)) 27 | num_triplets = torch.sum(torch.tensor(torch.greater(triplet_loss, 1e-16), dtype=torch.float32)) 28 | triplet_loss = torch.sum(triplet_loss) / (num_triplets + 1e-16) 29 | 30 | return triplet_loss 31 | 32 | 33 | def get_mask(targets): 34 | indices = torch.logical_not(torch.tensor(torch.eye(targets.shape[0]), dtype=torch.bool).cuda()) 35 | 36 | i_j = indices.unsqueeze(2) 37 | i_k = indices.unsqueeze(1) 38 | j_k = indices.unsqueeze(0) 39 | 40 | dist_indices = torch.logical_and(torch.logical_and(i_j, i_k), j_k) 41 | 42 | targets_equal = targets.unsqueeze(0).eq(targets.unsqueeze(1)) 43 | i_equal_j = targets_equal.unsqueeze(2) 44 | i_equal_k = targets_equal.unsqueeze(1) 45 | valid_labels = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k)) 46 | 47 | mask = torch.logical_and(valid_labels, dist_indices) 48 | 49 | return torch.tensor(mask, dtype=torch.float32) -------------------------------------------------------------------------------- /Code/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import statistics 3 | 4 | 5 | def test(test_loader, eval_loader, model, num_classes): 6 | predict, y_true = None, None 7 | trained_embedding, trained_label = None, None 8 | 9 | model.eval() 10 | with torch.no_grad(): 11 | for i, (inputs, label) in enumerate(eval_loader): 12 | inputs = inputs.cuda() 13 | label = label.cuda() 14 | 15 | train_embedding, _ = model(inputs) 16 | train_label = label.view(-1) 17 | if i == 0: 18 | trained_embedding = train_embedding 19 | trained_label = train_label 20 | else: 21 | trained_embedding = torch.vstack((trained_embedding, train_embedding)) 22 | trained_label = torch.hstack((trained_label, train_label)) 23 | 24 | for i, (inputs, label) in enumerate(test_loader): 25 | inputs = inputs.cuda() 26 | label = label.cuda() 27 | 28 | embedding, _ = model(inputs) 29 | 30 | for j, emb in enumerate(embedding): 31 | dist = torch.sum((emb - trained_embedding) ** 2, dim=-1) ** 0.5 32 | top_k = torch.argsort(dist)[:5] 33 | count = [0 for i in range(num_classes)] 34 | for k in trained_label[top_k]: 35 | count[k] += 1 36 | emb_label = torch.argmax(torch.tensor(count)) 37 | 38 | if i == 0 and j == 0: 39 | predict = emb_label 40 | else: 41 | predict = torch.hstack((predict, emb_label)) 42 | 43 | label = label.view(-1) 44 | 45 | if i == 0: 46 | pre = predict 47 | y_true = label 48 | else: 49 | pre = torch.hstack((pre, predict)) 50 | y_true = torch.hstack((y_true, label)) 51 | 52 | test_acc, test_sen, test_spe, test_auc = statistics(y_true.cpu(), pre.cpu()) 53 | 54 | return test_acc, test_sen, test_spe, test_auc -------------------------------------------------------------------------------- /Code/valid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import statistics 3 | 4 | 5 | def valid(train_loader, valid_loader, model, num_classes): 6 | predict, y_true = None, None 7 | trained_embedding, trained_label = None, None 8 | 9 | model.eval() 10 | with torch.no_grad(): 11 | for i, (inputs, label) in enumerate(train_loader): 12 | inputs = inputs.cuda() 13 | label = label.cuda() 14 | 15 | train_embedding, _ = model(inputs) 16 | train_label = label.view(-1) 17 | if i == 0: 18 | trained_embedding = train_embedding 19 | trained_label = train_label 20 | else: 21 | trained_embedding = torch.vstack((trained_embedding, train_embedding)) 22 | trained_label = torch.hstack((trained_label, train_label)) 23 | 24 | for i, (inputs, label) in enumerate(valid_loader): 25 | inputs = inputs.cuda() 26 | label = label.cuda() 27 | 28 | embedding, _ = model(inputs) 29 | 30 | for j, emb in enumerate(embedding): 31 | dist = torch.sum((emb - trained_embedding) ** 2, dim=-1) ** 0.5 32 | top_k = torch.argsort(dist)[:5] 33 | count = [0 for i in range(num_classes)] 34 | for k in trained_label[top_k]: 35 | count[k] += 1 36 | emb_label = torch.argmax(torch.tensor(count)) 37 | 38 | if i == 0 and j == 0: 39 | predict = emb_label 40 | else: 41 | predict = torch.hstack((predict, emb_label)) 42 | 43 | label = label.view(-1) 44 | 45 | if i == 0: 46 | pre = predict 47 | y_true = label 48 | else: 49 | pre = torch.hstack((pre, predict)) 50 | y_true = torch.hstack((y_true, label)) 51 | 52 | valid_acc, valid_sen, valid_spe, valid_auc = statistics(y_true.cpu(), pre.cpu()) 53 | 54 | return valid_acc, valid_sen, valid_spe, valid_auc -------------------------------------------------------------------------------- /Code/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as Data 4 | from utils import change_label, split_dataset 5 | 6 | 7 | class GetKfoldLoader(Data.Dataset): 8 | def __init__(self, data, datashape): 9 | super(GetKfoldLoader, self).__init__() 10 | self.data = data[:, :-1].reshape(-1, datashape[1], datashape[2]) 11 | self.label = data[:, -1].astype(np.int) 12 | self.label = change_label(self.label) 13 | 14 | self.data = torch.tensor(self.data, dtype=torch.float32) 15 | self.label = torch.tensor(self.label, dtype=torch.int64) 16 | 17 | def __len__(self): 18 | return self.data.shape[0] 19 | 20 | def __getitem__(self, item): 21 | return self.data[item], self.label[item] 22 | 23 | 24 | def load_data(data, label, batch_size, num_workers, ki=0, fold=10, valid_ratio=0.15): 25 | trainset, validset, testset = split_dataset(data, label, ki, fold, valid_ratio) 26 | evalset = np.concatenate((trainset, validset), axis=0) 27 | 28 | train_loader = GetKfoldLoader(data=trainset, datashape=data.shape) 29 | valid_loader = GetKfoldLoader(data=validset, datashape=data.shape) 30 | test_loader = GetKfoldLoader(data=testset, datashape=data.shape) 31 | eval_loader = GetKfoldLoader(data=evalset, datashape=data.shape) 32 | 33 | train_dataloader = Data.DataLoader( 34 | dataset=train_loader, 35 | batch_size=batch_size, 36 | num_workers=num_workers, 37 | shuffle=True, 38 | drop_last=False, 39 | ) 40 | valid_dataloader = Data.DataLoader( 41 | dataset=valid_loader, 42 | batch_size=batch_size, 43 | num_workers=num_workers, 44 | shuffle=False, 45 | drop_last=False, 46 | ) 47 | test_dataloader = Data.DataLoader( 48 | dataset=test_loader, 49 | batch_size=batch_size, 50 | num_workers=num_workers, 51 | shuffle=False, 52 | drop_last=False, 53 | ) 54 | eval_dataloader = Data.DataLoader( 55 | dataset=eval_loader, 56 | batch_size=batch_size, 57 | num_workers=num_workers, 58 | shuffle=False, 59 | drop_last=False, 60 | ) 61 | 62 | return train_dataloader, valid_dataloader, test_dataloader, eval_dataloader -------------------------------------------------------------------------------- /Code/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from utils import load_dataset 6 | from data_loader import load_data 7 | from loss_function import MyTriplet_loss 8 | from net import net 9 | from test import test 10 | from valid import valid 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | def train(model, criterions, optimizer, train_loader, valid_loader, fold, epochs, num_classes): 16 | train_size = len(train_loader) 17 | best_acc = 0 18 | best_model = None 19 | 20 | for epoch in range(epochs): 21 | start = time.time() 22 | model.train() 23 | epoch_loss = 0 24 | 25 | # train 26 | for i, (inputs, labels) in enumerate(train_loader): 27 | inputs = inputs.cuda() 28 | label = labels.view(-1).cuda() 29 | embedding, x = model(inputs) 30 | predict = model.frozen_forward(inputs) 31 | 32 | loss1 = criterions[0](predict, label) 33 | loss2 = criterions[1](embedding, label) 34 | loss = loss1 + loss2 35 | 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | 40 | epoch_loss += loss.item() 41 | CE_loss += loss1.item() 42 | triplet_loss += loss2.item() 43 | 44 | # validation 45 | valid_acc, valid_sen, valid_spe, valid_auc = valid(train_loader, valid_loader, model, num_classes) 46 | 47 | if valid_acc >= best_acc: 48 | best_acc = valid_acc 49 | best_model = model 50 | torch.save(best_model, "best_epoch_fold{}.pkl".format(fold)) 51 | 52 | epoch_loss = epoch_loss / train_size 53 | 54 | end = time.time() - start 55 | print("< F{} {:.0f}% {}/{} {:.3f}s >".format(fold, (epoch + 1) / epochs * 100, epoch + 1, epochs, end), end="") 56 | print('train_loss =', '{:.5f}'.format(epoch_loss), end="") 57 | print('valid_acc =', '{:.4f}'.format(valid_acc * 100)) 58 | 59 | 60 | if __name__ == '__main__': 61 | file_path = r"../../Data/Epilepsy" 62 | file_name = [r"X_data_gnd", r"G_all"] 63 | 64 | # Parameters 65 | timepoints, rois = 240, 90 66 | dim, depth, heads = 256, 1, 1 67 | dropout = 0.5 68 | batch_size = 128 69 | epochs = 200 70 | num_classes = 2 71 | 72 | # tasks 73 | pick = [0, 1] 74 | # pick = [0, 2] 75 | # pick = [1, 2] 76 | # pick = [0, 1, 2] 77 | 78 | # get data 79 | data, label = load_dataset(file_path, file_name, pick=pick) 80 | valid_ratio = 0.2 81 | alpha, beta = 0.5, 0.5 82 | 83 | # k-fold validation 84 | predict_acc, predict_auc, predict_sen, predict_spe = [], [], [], [] 85 | 86 | K = 10 87 | for ki in range(K): 88 | 89 | train_loader, valid_loader, test_loader, eval_loader = load_data( 90 | data, label, batch_size=batch_size, num_workers=0, 91 | ki=ki, fold=K, valid_ratio=valid_ratio) 92 | 93 | model = net(rois, timepoints, num_classes, depth, heads, dropout).cuda() 94 | 95 | criterion1 = nn.CrossEntropyLoss().cuda() 96 | criterion2 = MyTriplet_loss(margin=0.8, loss_weight=alpha).cuda() 97 | criterions = [criterion1, criterion2] 98 | 99 | optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=5e-5) 100 | 101 | train(model, criterions, optimizer, train_loader, valid_loader, ki + 1, num_classes=num_classes, epochs=epochs) 102 | 103 | test_model = torch.load("best_epoch_fold{}.pkl".format(ki + 1)) 104 | acc, SEN, SPE, auc = test(test_loader, eval_loader, test_model, num_classes) 105 | 106 | predict_acc.append(acc) 107 | predict_auc.append(auc) 108 | predict_sen.append(SEN) 109 | predict_spe.append(SPE) 110 | -------------------------------------------------------------------------------- /Code/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as scio 3 | import numpy as np 4 | import torch 5 | from sklearn.metrics import accuracy_score, roc_auc_score 6 | from sklearn.preprocessing import StandardScaler 7 | 8 | 9 | def statistics(y_true, pre): 10 | acc, auc, sen, spe = 0.0, 0.0, 0.0, 0.0 11 | try: 12 | ACC = accuracy_score(y_true, pre) 13 | AUC = roc_auc_score(y_true, pre) 14 | TP = torch.sum(y_true & pre) 15 | TN = len(y_true) - torch.sum(y_true | pre) 16 | true_sum = torch.sum(y_true) 17 | neg_sum = len(y_true) - true_sum 18 | SEN = TP / true_sum 19 | SPE = TN / neg_sum 20 | 21 | acc += ACC 22 | sen += SEN.cpu().numpy() 23 | spe += SPE.cpu().numpy() 24 | auc += AUC 25 | 26 | except ValueError as ve: 27 | print(ve) 28 | pass 29 | 30 | return acc, sen, spe, auc 31 | 32 | 33 | def pick_data(data, pick): 34 | picked_data = [] 35 | for i in range(len(data)): 36 | if data[i][-1] in pick: 37 | picked_data.append(data[i]) 38 | return np.array(picked_data) 39 | 40 | 41 | def change_label(labels): 42 | if 0 not in labels: 43 | for i, label in enumerate(labels): 44 | labels[i] -= 1 45 | else: 46 | for i, label in enumerate(labels): 47 | if labels[i] != 0: 48 | labels[i] = 1 49 | return labels 50 | 51 | 52 | def load_dataset(file_path, file_name, pick=[1, 2]): 53 | epilepsy = scio.loadmat(os.path.join(file_path, file_name[0])) 54 | dti = scio.loadmat(os.path.join(file_path, file_name[1]))['G'].transpose(2, 0, 1) 55 | mri = epilepsy["data"] 56 | gnd = epilepsy['gnd'][0, :] 57 | label = gnd.reshape((gnd.shape[0], -1)) 58 | 59 | b, h, w = mri.shape 60 | _, n, m = dti.shape 61 | scaler = StandardScaler() 62 | mri = mri.reshape(b, -1) 63 | dti = dti.reshape(b, -1) 64 | mri = scaler.fit_transform(mri) 65 | dti = scaler.fit_transform(dti) 66 | data = np.concatenate((mri, dti, label), axis=1) 67 | 68 | data = pick_data(data, pick) 69 | mri = data[:, :h * w] 70 | dti = data[:, h * w: -1] 71 | 72 | mri = mri.reshape(-1, h, w) 73 | dti = dti.reshape(-1, n, m) 74 | label = data[:, -1] 75 | return np.concatenate((mri, dti), axis=2), label 76 | 77 | 78 | def split_dataset(data, label, ki, K, valid_ratio=0.15): 79 | test = [] 80 | index = [x for x in range(K)] 81 | test_index = index.pop(ki) 82 | classes = list(set(label)) 83 | num_classes = len(classes) 84 | class_index = [[] for i in range(num_classes)] 85 | for i, x in enumerate(label): 86 | class_index[classes.index(x)].append(i) 87 | 88 | for x in class_index: 89 | np.random.shuffle(x) 90 | sample_index = [x for x in range(0, label.shape[0])] 91 | 92 | every_k_len = [] 93 | for x in class_index: 94 | if len(x) / K - len(x) // K < 0.5: 95 | every_k_len.append(len(x) // K) 96 | else: 97 | every_k_len.append(len(x) // K + 1) 98 | 99 | for i, x in enumerate(class_index): 100 | if test_index != K - 1: 101 | test.extend(x[every_k_len[i] * test_index: every_k_len[i] * (test_index + 1)]) 102 | else: 103 | test.extend(x[every_k_len[i] * test_index:]) 104 | 105 | test_flag = torch.tensor([True if x in test else False for x in sample_index]) 106 | 107 | train_index = list(set(sample_index) - set(test)) 108 | extract_valid = np.random.randint(len(train_index), size=int(len(train_index) * valid_ratio)) 109 | valid_index = [index if i in extract_valid else -1 for i, index in enumerate(train_index)] 110 | valid_index = list(set(valid_index) - {-1}) 111 | valid_flag = torch.tensor([True if x in valid_index else False for x in sample_index]) 112 | train_index = list(set(train_index) - set(valid_index)) 113 | train_flag = torch.tensor([True if x in train_index else False for x in sample_index]) 114 | 115 | b, h, w = data.shape 116 | data = data.reshape(b, -1) 117 | data_label = np.concatenate((data, label.reshape(b, -1)), axis=1) 118 | trainset = data_label[train_flag] 119 | testset = data_label[test_flag] 120 | validset = data_label[valid_flag] 121 | np.random.shuffle(trainset) 122 | np.random.shuffle(testset) 123 | np.random.shuffle(validset) 124 | 125 | return trainset, validset, testset -------------------------------------------------------------------------------- /Code/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import einsum 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | class FeedForward(nn.Module): 9 | def __init__(self, dim, hidden_dim, dropout): 10 | super().__init__() 11 | self.net = nn.Sequential( 12 | nn.Linear(dim, hidden_dim), 13 | nn.GELU(), 14 | nn.Dropout(dropout), 15 | nn.Linear(hidden_dim, dim), 16 | ) 17 | self.norm = nn.LayerNorm(dim) 18 | 19 | def forward(self, x): 20 | residual = x 21 | x = self.net(x) 22 | x = self.norm(x + residual) 23 | return x 24 | 25 | 26 | class MyAttention(nn.Module): 27 | def __init__(self, dim, heads, dim_head): 28 | super(MyAttention, self).__init__() 29 | self.heads = heads 30 | self.scale = dim_head ** -0.5 31 | self.to_Q = nn.Linear(dim, dim_head * heads, bias=False) 32 | self.to_K = nn.Linear(dim, dim_head * heads, bias=False) 33 | self.to_V = nn.Linear(dim, dim_head * heads, bias=False) 34 | self.norm = nn.LayerNorm(dim) 35 | 36 | def forward(self, x, y): 37 | residual = x 38 | b, n_x, d_x, h_x = *x.shape, self.heads 39 | b, n_y, d_y, h_y = *y.shape, self.heads 40 | 41 | q = self.to_Q(x).view(b, -1, self.heads, d_y // self.heads).transpose(1, 2) 42 | k = self.to_K(y).view(b, -1, self.heads, d_x // self.heads).transpose(1, 2) 43 | v = self.to_V(y).view(b, -1, self.heads, d_x // self.heads).transpose(1, 2) 44 | 45 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 46 | attn = dots.softmax(dim=-1) 47 | 48 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 49 | out = rearrange(out, 'b h n d -> b n (h d)') 50 | out = self.norm(out + residual) 51 | return out 52 | 53 | 54 | class MyCrossAttention(nn.Module): 55 | def __init__(self, rois, dim, heads, dim_head, dropout): 56 | super(MyCrossAttention, self).__init__() 57 | self.heads = heads 58 | self.scale = dim_head ** -0.5 59 | 60 | self.to_Q = nn.Linear(rois, rois, bias=False) 61 | self.to_K = nn.Linear(dim, dim_head * heads, bias=False) 62 | self.to_V = nn.Linear(dim, dim_head * heads, bias=False) 63 | self.norm = nn.LayerNorm(dim) 64 | 65 | def forward(self, x, y): 66 | residual = y 67 | b, n_x, d_x, h_x = *x.shape, self.heads 68 | b, n_y, d_y, h_y = *y.shape, self.heads 69 | 70 | q = self.to_Q(x).view(b, -1, 1, d_x).transpose(1, 2) 71 | q = q.repeat(1, self.heads, 1, 1) 72 | k = self.to_K(y).view(b, -1, self.heads, d_y // self.heads).transpose(1, 2) 73 | v = self.to_V(y).view(b, -1, self.heads, d_y // self.heads).transpose(1, 2) 74 | 75 | kkt = einsum('b h i d, b h j d -> b h i j', k, k) * self.scale 76 | dots = einsum('b h i d, b h j d -> b h i j', q, kkt) * self.scale 77 | attn = dots.softmax(dim=-1) 78 | 79 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 80 | out = rearrange(out, 'b h n d -> b n (h d)') 81 | out = self.norm(out + residual) 82 | return out 83 | 84 | 85 | class MyDecoderLayer(nn.Module): 86 | def __init__(self, rois, dim, heads, dim_head, mlp_dim, dropout=0.5): 87 | super(MyDecoderLayer, self).__init__() 88 | self.SelfAttention = MyAttention(rois, heads=1, dim_head=rois) 89 | self.CrossAttention = MyCrossAttention(rois, dim, heads=heads, dim_head=dim_head, dropout=dropout) 90 | self.norm = nn.LayerNorm(dim) 91 | self.FeedForward = FeedForward(dim, mlp_dim, dropout) 92 | 93 | def forward(self, x, enc_out): 94 | out = self.SelfAttention(x, x) 95 | out = self.CrossAttention(out, enc_out) 96 | out = self.FeedForward(out) 97 | return out 98 | 99 | 100 | class MyDecoder(nn.Module): 101 | def __init__(self, rois, dim, depth, heads, dim_head, mlp_dim, dropout): 102 | super(MyDecoder, self).__init__() 103 | self.layers = nn.ModuleList( 104 | [MyDecoderLayer(rois, dim, heads, dim_head, mlp_dim, dropout=dropout) for _ in range(depth)]) 105 | 106 | def forward(self, x, enc_out): 107 | for layer in self.layers: 108 | x = layer(x, enc_out) 109 | return x 110 | 111 | 112 | class MyEncoderLayer(nn.Module): 113 | def __init__(self, dim, heads, dim_head, mlp_dim, dropout): 114 | super(MyEncoderLayer, self).__init__() 115 | self.SelfAttention = MyAttention(dim, heads, dim_head) 116 | self.norm = nn.LayerNorm(dim) 117 | self.FeedForward = FeedForward(dim, mlp_dim, dropout=dropout) 118 | 119 | def forward(self, x): 120 | x = self.SelfAttention(x, x) 121 | x = self.FeedForward(x) 122 | return x 123 | 124 | 125 | class MyEncoder(nn.Module): 126 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 127 | super(MyEncoder, self).__init__() 128 | self.layers = nn.ModuleList( 129 | [MyEncoderLayer(dim, heads, dim_head, mlp_dim, dropout) for _ in range(depth)]) 130 | 131 | def forward(self, x): 132 | for layer in self.layers: 133 | x = layer(x) 134 | return x 135 | 136 | 137 | class net(nn.Module): 138 | def __init__(self, rois, timepoints, num_classes, depth, heads, dropout): 139 | super(net, self).__init__() 140 | self.dim = timepoints 141 | self.rois = rois 142 | mlp_dim = self.dim * 3 143 | 144 | self.encoder = MyEncoder(self.dim, depth, heads, self.dim // heads, mlp_dim, dropout) 145 | self.decoder = MyDecoder(rois, self.dim, depth, heads, self.dim // heads, mlp_dim, dropout) 146 | 147 | self.to_latent = nn.Identity() 148 | self.fc1 = nn.Sequential( 149 | nn.Linear(45 * 45, 1024), 150 | nn.LeakyReLU(), 151 | nn.Dropout(dropout), 152 | nn.Linear(1024, 128), 153 | nn.LeakyReLU(), 154 | ) 155 | self.mlp_head = nn.Sequential( 156 | nn.LayerNorm(128), 157 | nn.LeakyReLU(), 158 | nn.Linear(128, num_classes) 159 | ) 160 | self.norm = nn.LayerNorm(self.rois) 161 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=(2, 2), padding=(0, 0)) 162 | 163 | def forward(self, inputs): 164 | mri = inputs[:, :, : -self.rois] 165 | dti = inputs[:, :, -self.rois:] 166 | 167 | x = self.encoder(mri) 168 | x_out = torch.matmul(x, x.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.rois)) 169 | 170 | y = self.decoder(dti, x) 171 | y_out = torch.matmul(y, y.transpose(-1, -2) / torch.sqrt(torch.tensor(self.rois))) 172 | 173 | out = x_out + y_out 174 | out = self.maxpool(out) 175 | out = torch.flatten(out, start_dim=1) 176 | out = self.fc1(out) 177 | 178 | out_norm = F.normalize(out, p=2, dim=1) 179 | return out_norm, out 180 | 181 | def frozen_forward(self, x): 182 | with torch.no_grad(): 183 | _, x = self.forward(x) 184 | x = self.mlp_head(x) 185 | return torch.softmax(x, dim=-1) --------------------------------------------------------------------------------