├── LICENSE
├── Model.py
├── Utils.py
├── README.md
├── Dataset.py
├── RBFSVM_Kfold.py
├── linearSVM_Loso.py
├── RBFSVM_Loso.py
├── linearSVM_Kfold.py
├── GCN_Kfold.py
└── GCN_Loso.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 alien18

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as func
from torch_geometric.nn import ChebConv, global_mean_pool


class GCN(torch.nn.Module):
    """GCN model (the network architecture can be modified)."""

    def __init__(self,
                 num_features,
                 num_classes,
                 k_order,
                 dropout=.5):
        super(GCN, self).__init__()

        self.p = dropout

        self.conv1 = ChebConv(int(num_features), 64, K=k_order)
        self.conv2 = ChebConv(64, 64, K=k_order)
        self.conv3 = ChebConv(64, 128, K=k_order)

        self.lin1 = torch.nn.Linear(128, int(num_classes))

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = data.batch

        x = func.relu(self.conv1(x, edge_index, edge_attr))
        x = func.dropout(x, p=self.p, training=self.training)
        x = func.relu(self.conv2(x, edge_index, edge_attr))
        x = func.dropout(x, p=self.p, training=self.training)
        x = func.relu(self.conv3(x, edge_index, edge_attr))

        x = global_mean_pool(x, batch)
        x = self.lin1(x)
        return x
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.spatial import distance
from scipy.sparse import coo_matrix, csr_matrix


def compute_KNN_graph(matrix, k_degree=10, metric='euclidean'):
    """Calculate the adjacency matrix from the connectivity matrix."""

    dist = distance.pdist(matrix, metric)
    dist = distance.squareform(dist)

    idx = np.argsort(dist)[:, 1:k_degree + 1]
    dist.sort()
    dist = dist[:, 1:k_degree + 1]

    A = adjacency(dist, idx).astype(np.float32)

    return A


def adjacency(dist, idx):

    m, k = dist.shape
    assert idx.shape == (m, k)
    assert dist.min() >= 0

    # Weights.
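    # Gaussian kernel weighting: sigma^2 is set to the mean squared distance
    # to the k-th (farthest kept) neighbour, so edge weights decay as
    # exp(-d^2 / sigma^2) and more distant neighbours receive smaller weights.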
    sigma2 = np.mean(dist[:, -1]) ** 2
    dist = np.exp(- dist ** 2 / sigma2)

    # Weight matrix.
    I = np.arange(0, m).repeat(k)
    J = idx.reshape(m * k)
    V = dist.reshape(m * k)
    W = coo_matrix((V, (I, J)), shape=(m, m))

    # No self-connections. COO matrices do not support item assignment in
    # recent SciPy, so convert to CSR before zeroing the diagonal.
    W = W.tocsr()
    W.setdiag(0)
    W.eliminate_zeros()  # drop any explicit zeros left behind by setdiag

    # Non-directed graph.
    bigger = W.T > W
    W = W - W.multiply(bigger) + W.T.multiply(bigger)

    assert W.nnz % 2 == 0
    assert np.abs(W - W.T).mean() < 1e-10
    assert isinstance(W, csr_matrix)
    return W.toarray()  # torch.from_numpy rejects the np.matrix that todense() returns
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GCN_SCZ_Classification

This repository provides the core code and toolboxes used for the analyses in the paper "Graph convolutional networks reveal network-level functional dysconnectivity in schizophrenia" by Lei et al. Please see the paper for a full description of the data analysis.

# Overview
Contents include demo data and source code implementing a graph convolutional network (GCN), a linear support vector machine (SVM) and a non-linear SVM with a radial basis function (RBF) kernel on a large multi-site schizophrenia fMRI dataset. All custom code was tested on an Ubuntu 20.04 LTS Linux PC.

# Requirements
- Python (>= 3.5)
- scikit-learn
- PyTorch
- PyTorch Geometric
- SciPy
- NumPy
- pandas

# Toolboxes

All other toolboxes and code used in our study for image preprocessing, harmonization, ancillary analyses and visualization are listed below:

- [SPM12](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/)
- [NeuroComBat-sklearn](https://github.com/Warvito/neurocombat_sklearn)
- [Neuroharmony](https://github.com/garciadias/Neuroharmony)
- [GraphSaliencyMap](https://github.com/sarslancs/graph_saliency_maps)
- [GRETNA](https://www.nitrc.org/projects/gretna/)
- [BrainNetViewer](https://www.nitrc.org/projects/bnv/)
--------------------------------------------------------------------------------
/Dataset.py:
--------------------------------------------------------------------------------
from pathlib import Path

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import dense_to_sparse

from Utils import compute_KNN_graph


class ConnectivityData(InMemoryDataset):
    """Dataset for the connectivity data."""

    def __init__(self,
                 root):
        super(ConnectivityData, self).__init__(root, None, None)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        file_paths = sorted(list(Path(self.raw_dir).glob("*.txt")))
        return [str(file_path.name) for file_path in file_paths]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def set_new_indices(self):
        # Note: relies on a PyG-internal attribute (renamed `_indices` in newer releases).
        self.__indices__ = list(range(self.len()))

    def process(self):
        labels = np.genfromtxt(Path(self.raw_dir) / "Labels.csv")

        data_list = []
        for filename, y in zip(self.raw_paths, labels):
            y = torch.tensor([y]).long()
            connectivity = np.genfromtxt(filename)
            x = torch.from_numpy(connectivity).float()

            adj = compute_KNN_graph(connectivity)
            adj = torch.from_numpy(adj).float()
            edge_index, edge_attr = dense_to_sparse(adj)

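            # One graph per subject: node features x are the rows of the
            # connectivity matrix, edges and weights come from the KNN graph,
            # and y is the diagnostic label.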
            data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
--------------------------------------------------------------------------------
/RBFSVM_Kfold.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.metrics import confusion_matrix


dpath = './data/Main/raw'
C_val = np.logspace(-4, 3, 8)
gamma_val = np.logspace(-3, 2, 6)
n_jobs = 5

print('Loading data ...')
files = sorted(list(Path(dpath).glob("*.txt")))
features = []

for file in files:
    fc_mat = np.genfromtxt(file)
    fc_vec = np.concatenate([fc_mat[i][:i] for i in range(fc_mat.shape[0])])
    features.append(fc_vec)
features = np.array(features)
labels = np.genfromtxt(osp.join(dpath, 'Labels.csv'), dtype='int32')

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=90)
nested_skf = StratifiedKFold(n_splits=10, shuffle=True)
eval_metrics = np.zeros((skf.n_splits, 3))

for n_fold, (train, test) in enumerate(skf.split(features, labels)):

    print('Processing the No.%i cross-validation in %i-fold CV' % (n_fold + 1, skf.n_splits))
    x_train, y_train = features[train], labels[train]
    x_test, y_test = features[test], labels[test]

    init_clf = SVC(kernel='rbf')
    grid = GridSearchCV(init_clf, {'C': C_val, 'gamma': gamma_val}, cv=nested_skf, scoring='balanced_accuracy',
                        n_jobs=n_jobs)
    grid.fit(x_train, y_train)
    print('    The best parameter C: %.2e and Gamma: %.2e with BAC of %f'
          % (grid.best_params_['C'], grid.best_params_['gamma'], grid.best_score_))
    clf = SVC(kernel='rbf', C=grid.best_params_['C'], gamma=grid.best_params_['gamma'])
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_test)

    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    fold_sen = tp / (tp + fn)
    fold_spe = tn / (tn + fp)
    fold_bac = (fold_sen + fold_spe) / 2
    eval_metrics[n_fold, 0] = fold_sen
    eval_metrics[n_fold, 1] = fold_spe
    eval_metrics[n_fold, 2] = fold_bac

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['Fold_%02i' % (i + 1) for i in range(skf.n_splits)]
print(eval_df)
print('\nAverage Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))
--------------------------------------------------------------------------------
/linearSVM_Loso.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.svm import SVC
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold, GridSearchCV
from sklearn.metrics import confusion_matrix


dpath = './data/Main/raw'
C_val = np.logspace(-4, 3, 8)
n_jobs = 5

print('Loading data ...')
files = sorted(list(Path(dpath).glob("*.txt")))
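# Vectorise each subject's symmetric FC matrix by stacking the strictly lower
# triangle row by row; an equivalent one-liner (not used by the original
# script, assuming fc_mat is square and symmetric) would be:
#     fc_vec = fc_mat[np.tril_indices_from(fc_mat, k=-1)]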
features = []

for file in files:
    fc_mat = np.genfromtxt(file)
    fc_vec = np.concatenate([fc_mat[i][:i] for i in range(fc_mat.shape[0])])
    features.append(fc_vec)
features = np.array(features)
labels = np.genfromtxt(osp.join(dpath, 'Labels.csv'), dtype='int32')
sites = np.genfromtxt(osp.join(dpath, 'sites.csv'))

logo = LeaveOneGroupOut()
nested_skf = StratifiedKFold(n_splits=10, shuffle=True)
n_sites = np.unique(sites).shape[0]
eval_metrics = np.zeros((n_sites - 1, 3))  # not testing on site 6 (HC site (Dataset 2 in manuscript))

for n_fold, (train_ind, test_ind) in enumerate(logo.split(features, labels, groups=sites)):

    if n_fold < 5:  # not testing on site 6 (HC site (Dataset 2 in manuscript))
        print('Processing the No.%i cross-validation in %i-fold CV' % (n_fold + 1, n_sites - 1))
        x_train, y_train = features[train_ind], labels[train_ind]
        x_test, y_test = features[test_ind], labels[test_ind]

        init_clf = SVC(kernel='linear')
        grid = GridSearchCV(init_clf, {'C': C_val}, cv=nested_skf, scoring='balanced_accuracy', n_jobs=n_jobs)
        grid.fit(x_train, y_train)
        print('    The best parameter C: %.2e with BAC of %f' % (grid.best_params_['C'], grid.best_score_))
        clf = SVC(kernel='linear', C=grid.best_params_['C'])
        clf.fit(x_train, y_train)
        y_pred = clf.predict(x_test)

        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        fold_sen = tp / (tp + fn)
        fold_spe = tn / (tn + fp)
        fold_bac = (fold_sen + fold_spe) / 2
        eval_metrics[n_fold, 0] = fold_sen
        eval_metrics[n_fold, 1] = fold_spe
        eval_metrics[n_fold, 2] = fold_bac

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['CV_' + str(i + 1) for i in range(n_sites - 1)]
print(eval_df)
print('\nAverage Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))
--------------------------------------------------------------------------------
/RBFSVM_Loso.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.svm import SVC
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold, GridSearchCV
from sklearn.metrics import confusion_matrix


dpath = './data/Main/raw'
C_val = np.logspace(-4, 3, 8)
gamma_val = np.logspace(-3, 2, 6)
n_jobs = 5

print('Loading data ...')
files = sorted(list(Path(dpath).glob("*.txt")))
features = []

for file in files:
    fc_mat = np.genfromtxt(file)
    fc_vec = np.concatenate([fc_mat[i][:i] for i in range(fc_mat.shape[0])])
    features.append(fc_vec)
features = np.array(features)
labels = np.genfromtxt(osp.join(dpath, 'Labels.csv'), dtype='int32')
sites = np.genfromtxt(osp.join(dpath, 'sites.csv'))

logo = LeaveOneGroupOut()
nested_skf = StratifiedKFold(n_splits=10, shuffle=True)
n_sites = np.unique(sites).shape[0]
eval_metrics = np.zeros((n_sites - 1, 3))  # not testing on site 6 (HC site (Dataset 2 in manuscript))

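# Leave-one-site-out CV: each fold trains on all sites but one and tests on
# the held-out site; folds follow site order, so `n_fold < 5` skips site 6.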
for n_fold, (train_ind, test_ind) in enumerate(logo.split(features, labels, groups=sites)):

    if n_fold < 5:  # not testing on site 6 (HC site (Dataset 2 in manuscript))
        print('Processing the No.%i cross-validation in %i-fold CV' % (n_fold + 1, n_sites - 1))
        x_train, y_train = features[train_ind], labels[train_ind]
        x_test, y_test = features[test_ind], labels[test_ind]

        init_clf = SVC(kernel='rbf')
        grid = GridSearchCV(init_clf, {'C': C_val, 'gamma': gamma_val}, cv=nested_skf, scoring='balanced_accuracy',
                            n_jobs=n_jobs)
        grid.fit(x_train, y_train)
        print('    The best parameter C: %.2e and Gamma: %.2e with BAC of %f' %
              (grid.best_params_['C'], grid.best_params_['gamma'], grid.best_score_))
        clf = SVC(kernel='rbf', C=grid.best_params_['C'], gamma=grid.best_params_['gamma'])
        clf.fit(x_train, y_train)
        y_pred = clf.predict(x_test)

        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        fold_sen = tp / (tp + fn)
        fold_spe = tn / (tn + fp)
        fold_bac = (fold_sen + fold_spe) / 2
        eval_metrics[n_fold, 0] = fold_sen
        eval_metrics[n_fold, 1] = fold_spe
        eval_metrics[n_fold, 2] = fold_bac

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['CV_' + str(i + 1) for i in range(n_sites - 1)]
print(eval_df)
print('\nAverage Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))
--------------------------------------------------------------------------------
/linearSVM_Kfold.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.metrics import confusion_matrix


dpath = './data/Main/raw'
C_val = np.logspace(-4, 3, 8)
coefs_ = []
n_jobs = 5
n_regions = 90

print('Loading data ...')
files = sorted(list(Path(dpath).glob("*.txt")))
features = []

for file in files:
    fc_mat = np.genfromtxt(file)
    fc_vec = np.concatenate([fc_mat[i][:i] for i in range(fc_mat.shape[0])])
    features.append(fc_vec)
features = np.array(features)
labels = np.genfromtxt(osp.join(dpath, 'Labels.csv'), dtype='int32')

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=99)
nested_skf = StratifiedKFold(n_splits=10, shuffle=True)
eval_metrics = np.zeros((skf.n_splits, 3))

for n_fold, (train, test) in enumerate(skf.split(features, labels)):

    print('Processing the No.%i cross-validation in %i-fold CV' % (n_fold + 1, skf.n_splits))
    x_train, y_train = features[train], labels[train]
    x_test, y_test = features[test], labels[test]

    init_clf = SVC(kernel='linear')
    grid = GridSearchCV(init_clf, {'C': C_val}, cv=nested_skf, scoring='balanced_accuracy', n_jobs=n_jobs)
    grid.fit(x_train, y_train)
    print('    The best parameter C: %.2e with BAC of %f' % (grid.best_params_['C'], grid.best_score_))
    clf = SVC(kernel='linear', C=grid.best_params_['C'])
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_test)

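    # For binary labels {0, 1}, sklearn's confusion_matrix ravels to
    # (tn, fp, fn, tp); sensitivity = TP/(TP+FN), specificity = TN/(TN+FP).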
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    fold_sen = tp / (tp + fn)
    fold_spe = tn / (tn + fp)
    fold_bac = (fold_sen + fold_spe) / 2
    eval_metrics[n_fold, 0] = fold_sen
    eval_metrics[n_fold, 1] = fold_spe
    eval_metrics[n_fold, 2] = fold_bac

    # Fold the SVM weight vector back into a symmetric region x region matrix.
    weights = clf.coef_
    C = np.zeros((n_regions, n_regions))
    idx = 0
    for a in range(1, n_regions):
        for b in range(a):
            C[a][b] = weights[0, idx]
            idx = idx + 1
    C = C + C.T
    C = np.abs(C)
    coefs_.append(C)

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['Fold_%02i' % (i + 1) for i in range(skf.n_splits)]
print(eval_df)
print('\nAverage Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))

mean_coefs_ = np.mean(np.array(coefs_), axis=0)
reg_coefs_ = np.sum(mean_coefs_, axis=0) / (n_regions - 1)
top_bid = np.argsort(-reg_coefs_)[:10]
print(top_bid + 1)  # 1-based indices of the ten most discriminative regions
--------------------------------------------------------------------------------
/GCN_Kfold.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import os.path as osp
import warnings

import torch
import torch.nn.functional as func
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix

from Model import GCN
from Dataset import ConnectivityData


def GCN_train(loader):
    model.train()

    train_loss_all = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        train_loss = func.cross_entropy(output, data.y)
        train_loss.backward()
        train_loss_all += data.num_graphs * train_loss.item()
        optimizer.step()
    return train_loss_all / len(train_dataset)


def GCN_test(loader):
    model.eval()

    pred = []
    label = []
    val_loss_all = 0
    for data in loader:
        data = data.to(device)
        output = model(data)
        val_loss = func.cross_entropy(output, data.y)
        val_loss_all += data.num_graphs * val_loss.item()
        pred.append(func.softmax(output, dim=1).max(dim=1)[1])
        label.append(data.y)

    y_pred = torch.cat(pred, dim=0).cpu().detach().numpy()
    y_true = torch.cat(label, dim=0).cpu().detach().numpy()
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    epoch_sen = tp / (tp + fn)
    epoch_spe = tn / (tn + fp)
    epoch_bac = (epoch_sen + epoch_spe) / 2
    return epoch_sen, epoch_spe, epoch_bac, val_loss_all / len(val_dataset)


warnings.filterwarnings("ignore")
dataset = ConnectivityData('./data_demo/Main')

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=99)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
labels = np.genfromtxt(osp.join(dataset.raw_dir, 'Labels.csv'))
eval_metrics = np.zeros((skf.n_splits, 3))

for n_fold, (train_val, test) in enumerate(skf.split(labels, labels)):

    model = GCN(dataset.num_features, dataset.num_classes, 6).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
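    # Hold out a validation set for model selection: test_size=0.11 of the
    # 90% train+val portion is roughly 10% of all subjects; validation loss
    # drives the checkpointing below.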
    train_val_dataset, test_dataset = dataset[train_val.tolist()], dataset[test.tolist()]
    train_val_labels = labels[train_val]
    train_val_index = np.arange(len(train_val_dataset))

    train, val, _, _ = train_test_split(train_val_index, train_val_labels, test_size=0.11, shuffle=True,
                                        stratify=train_val_labels)
    train_dataset, val_dataset = train_val_dataset[train.tolist()], train_val_dataset[val.tolist()]
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

    min_v_loss = np.inf
    for epoch in range(50):
        t_loss = GCN_train(train_loader)
        val_sen, val_spe, val_bac, v_loss = GCN_test(val_loader)
        test_sen, test_spe, test_bac, _ = GCN_test(test_loader)

        if min_v_loss > v_loss:
            min_v_loss = v_loss
            best_val_bac = val_bac
            best_test_sen, best_test_spe, best_test_bac = test_sen, test_spe, test_bac
            torch.save(model.state_dict(), 'best_model_%02i.pth' % (n_fold + 1))
        print('CV: {:03d}, Epoch: {:03d}, Val Loss: {:.5f}, Val BAC: {:.5f}, Test BAC: {:.5f}, TEST SEN: {:.5f}, '
              'TEST SPE: {:.5f}'.format(n_fold + 1, epoch + 1, min_v_loss, best_val_bac, best_test_bac,
                                        best_test_sen, best_test_spe))

    eval_metrics[n_fold, 0] = best_test_sen
    eval_metrics[n_fold, 1] = best_test_spe
    eval_metrics[n_fold, 2] = best_test_bac

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['Fold_%02i' % (i + 1) for i in range(skf.n_splits)]
print(eval_df)
print('Average Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))
--------------------------------------------------------------------------------
/GCN_Loso.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import os.path as osp
import warnings

import torch
import torch.nn.functional as func
from torch_geometric.loader import DataLoader
from sklearn.model_selection import LeaveOneGroupOut, train_test_split
from sklearn.metrics import confusion_matrix

from Model import GCN
from Dataset import ConnectivityData


def GCN_train(loader):
    model.train()

    train_loss_all = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        train_loss = func.cross_entropy(output, data.y)
        train_loss.backward()
        train_loss_all += data.num_graphs * train_loss.item()
        optimizer.step()
    return train_loss_all / len(train_dataset)


def GCN_test(loader):
    model.eval()

    pred = []
    label = []
    val_loss_all = 0
    for data in loader:
        data = data.to(device)
        output = model(data)
        val_loss = func.cross_entropy(output, data.y)
        val_loss_all += data.num_graphs * val_loss.item()
        pred.append(func.softmax(output, dim=1).max(dim=1)[1])
        label.append(data.y)

    y_pred = torch.cat(pred, dim=0).cpu().detach().numpy()
    y_true = torch.cat(label, dim=0).cpu().detach().numpy()
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
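    # Metrics from the binary confusion matrix, as in GCN_Kfold.py:
    # sensitivity = TP/(TP+FN), specificity = TN/(TN+FP), BAC = their mean.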
    epoch_sen = tp / (tp + fn)
    epoch_spe = tn / (tn + fp)
    epoch_bac = (epoch_sen + epoch_spe) / 2
    return epoch_sen, epoch_spe, epoch_bac, val_loss_all / len(val_dataset)


warnings.filterwarnings("ignore")
dataset = ConnectivityData('./data_demo/Main')

logo = LeaveOneGroupOut()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
labels = np.genfromtxt(osp.join(dataset.raw_dir, 'Labels.csv'))
sites = np.genfromtxt(osp.join(dataset.raw_dir, 'sites.csv'))
n_sites = np.unique(sites).shape[0]
eval_metrics = np.zeros((n_sites - 1, 3))  # not testing on site 6 (HC site (Dataset 2 in manuscript))

for n_fold, (train_val, test) in enumerate(logo.split(labels, labels, groups=sites)):

    if n_fold < 5:  # not testing on site 6 (HC site (Dataset 2 in manuscript))
        model = GCN(dataset.num_features, dataset.num_classes, 6).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        train_val_dataset, test_dataset = dataset[train_val.tolist()], dataset[test.tolist()]
        train_val_labels = labels[train_val]
        train_val_index = np.arange(len(train_val_dataset))

        train, val, _, _ = train_test_split(train_val_index, train_val_labels, test_size=0.11, shuffle=True,
                                            stratify=train_val_labels)
        train_dataset, val_dataset = train_val_dataset[train.tolist()], train_val_dataset[val.tolist()]
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

        min_v_loss = np.inf
        for epoch in range(50):
            t_loss = GCN_train(train_loader)
            val_sen, val_spe, val_bac, v_loss = GCN_test(val_loader)
            test_sen, test_spe, test_bac, _ = GCN_test(test_loader)

            if min_v_loss > v_loss:
                min_v_loss = v_loss
                best_val_bac = val_bac
                best_test_sen, best_test_spe, best_test_bac = test_sen, test_spe, test_bac
                torch.save(model.state_dict(), 'best_model_%02i.pth' % (n_fold + 1))
            print('CV: {:03d}, Epoch: {:03d}, Val Loss: {:.5f}, Val BAC: {:.5f}, Test BAC: {:.5f}, TEST SEN: {:.5f}, '
                  'TEST SPE: {:.5f}'.format(n_fold + 1, epoch + 1, min_v_loss, best_val_bac, best_test_bac,
                                            best_test_sen, best_test_spe))

        eval_metrics[n_fold, 0] = best_test_sen
        eval_metrics[n_fold, 1] = best_test_spe
        eval_metrics[n_fold, 2] = best_test_bac

eval_df = pd.DataFrame(eval_metrics)
eval_df.columns = ['SEN', 'SPE', 'BAC']
eval_df.index = ['Fold_%02i' % (i + 1) for i in range(n_sites - 1)]
print(eval_df)
print('Average Sensitivity: %.4f±%.4f' % (eval_metrics[:, 0].mean(), eval_metrics[:, 0].std()))
print('Average Specificity: %.4f±%.4f' % (eval_metrics[:, 1].mean(), eval_metrics[:, 1].std()))
print('Average Balanced Accuracy: %.4f±%.4f' % (eval_metrics[:, 2].mean(), eval_metrics[:, 2].std()))
--------------------------------------------------------------------------------