├── src ├── models │ ├── __init__.py │ ├── MLP.py │ ├── BrainGNN.py │ ├── GCN.py │ └── GATv2.py ├── utils │ ├── __init__.py │ ├── modified_args.py │ ├── get_transform.py │ ├── attention.py │ ├── model_utils.py │ ├── simple_param.py │ ├── metrics.py │ ├── mixup.py │ ├── save_model.py │ ├── explain.py │ ├── train_and_evaluate.py │ ├── sample_selection.py │ ├── visualization.py │ ├── losses.py │ ├── SPD.py │ └── data_utils.py ├── dataset │ ├── __init__.py │ ├── base_transform.py │ ├── maskable_list.py │ ├── brain_data.py │ ├── private │ │ └── load_private.py │ ├── utils.py │ ├── brain_dataset.py │ └── transforms.py └── nni_configs │ ├── search_space.json │ └── config.yml ├── .gitignore ├── LICENSE ├── requirements.txt ├── community_networks ├── roi_names_mod.csv └── roi_coords.csv ├── README.md ├── nni_test.py └── main.py /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | pass 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixup import mixup, mixup_criterion 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | datasets/ 3 | outputs/ 4 | model_checkpoints/ 5 | *__pycache__/ -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .brain_data import BrainData 2 | from .brain_dataset import BrainDataset 3 | -------------------------------------------------------------------------------- /src/nni_configs/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_heads": { 3 | "_type": "choice", 4 | "_value": [2, 4, 6, 8, 10] 5 | } 6 
class MaskableList(list):
    """List that can additionally be subscripted with a mask.

    Subscripts that ``list`` itself understands are handled by ``list``.
    Any subscript that ``list`` rejects with ``TypeError`` is treated as an
    iterable of truthy/falsy selectors and applied with
    ``itertools.compress``; the selection is returned as a ``MaskableList``.
    """

    def __getitem__(self, index):
        try:
            item = list.__getitem__(self, index)
        except TypeError:
            # Not a subscript list accepts: use ``index`` as a selector mask.
            item = MaskableList(compress(self, index))
        return item
class MLP(torch.nn.Module):
    """Residual multi-layer perceptron with an optional linear classifier.

    The network is ``num_layers`` blocks of ``Linear`` + ``activation()``
    mapping ``input_dim`` to ``hidden_dim``, added to a linear shortcut of
    the input. When ``n_classes`` is non-zero a linear classifier head is
    attached on top of the hidden representation.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of every hidden layer and of the shortcut output.
        num_layers: number of ``Linear`` + activation blocks.
        activation: zero-argument callable returning an activation module
            (e.g. ``torch.nn.ReLU``).
        n_classes: number of classifier outputs; ``0`` disables the head.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, activation, n_classes=0):
        super().__init__()
        layers = [torch.nn.Linear(input_dim, hidden_dim), activation()]
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        self.net = torch.nn.Sequential(*layers)
        self.shortcut = torch.nn.Linear(input_dim, hidden_dim)

        if n_classes != 0:
            self.classifier = torch.nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return ``(hidden, logits)`` with a classifier head, else ``hidden``."""
        out = self.net(x) + self.shortcut(x)
        if hasattr(self, "classifier"):
            return out, self.classifier(out)
        # Without a head the hidden representation is the result; the
        # previous bare ``return`` silently produced ``None`` here.
        return out
def get_transform(transform_type: str) -> BaseTransform:
    """
    Construct the transform registered under the given name
    :param transform_type: one of "identity", "degree", "degree_bin", "LDP",
        "adj", "node2vec", "eigenvector", "eigen_norm"
    :return: a newly constructed transform object
    :raises ValueError: if transform_type is not one of the names above
    """
    # Zero-argument factories keep the class lookup lazy: only the class of
    # the requested transform is resolved and instantiated.
    factories = {
        "identity": lambda: Identity(),
        "degree": lambda: Degree(),
        "degree_bin": lambda: DegreeBin(),
        "LDP": lambda: LDPTransform(),
        "adj": lambda: Adj(),
        "node2vec": lambda: Node2Vec(),
        "eigenvector": lambda: Eigenvector(),
        "eigen_norm": lambda: EigenNorm(),
    }
    make = factories.get(transform_type)
    if make is None:
        raise ValueError("Unknown transform type: {}".format(transform_type))
    return make()
def save_attn_scores(attn, data):
    """
    Save aggregated attention scores to file

    Args:
        attn: pair ``(edge_index, alpha)``. ``edge_index`` is indexed as a
            ``(2, num_edges)`` tensor. ``alpha`` holds one row per edge and is
            either 1-D or 2-D (presumably ``(num_edges, num_heads)`` -
            confirm against the attention layer that produces it).
        data: sample whose ``s_ID`` tensor holds the subject id used in the
            output file names.

    Raises:
        ValueError: if ``alpha`` is neither 1-D nor 2-D.

    Side effects:
        Writes ``results/explanations/graphs/attn_graphs_<id>.graphml`` and
        ``results/explanations/attn_scores/attn_out_<id>.pt``, creating the
        directories if needed.
    """
    import os  # local import: only needed to create the output directories

    edge_idx, alpha = attn[0], attn[1]
    # Keep only the attention rows that have a matching edge column.
    alpha = alpha[: edge_idx.size(1)]
    if alpha.dim() == 2:
        # Aggregate across the second dimension by taking the maximum.
        alpha_max = alpha.max(dim=-1).values
    elif alpha.dim() == 1:
        # Already one score per edge; previously this path left
        # ``alpha_max`` unbound and crashed below.
        alpha_max = alpha
    else:
        raise ValueError(
            f"Expected 1-D or 2-D attention scores, got {alpha.dim()} dimensions"
        )

    G = nx.Graph()
    edge_idx_cpu = edge_idx.cpu().detach().numpy()
    edge_w_cpu = alpha_max.cpu().detach().numpy()
    for (u, v), w in zip(edge_idx_cpu.T, edge_w_cpu):
        G.add_edge(u, v, weight=w)

    subject_id = int(data.s_ID.item())
    graph_dir = "results/explanations/graphs"
    score_dir = "results/explanations/attn_scores"
    os.makedirs(graph_dir, exist_ok=True)
    os.makedirs(score_dir, exist_ok=True)

    # Save attention graph
    nx.write_graphml_lxml(G, f"{graph_dir}/attn_graphs_{subject_id}.graphml")

    # Save attention coefficients
    torch.save((edge_idx, alpha_max), f"{score_dir}/attn_out_{subject_id}.pt")
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.22.1 2 | cached-property==1.5.2 3 | certifi==2023.7.22 4 | chardet==5.1.0 5 | cycler==0.11.0 6 | decorator 7 | future==0.18.3 8 | googledrivedownloader==0.4 9 | h5py==3.9.0 10 | idna==3.4 11 | imblearn 12 | isodate==0.6.1 13 | Jinja2==3.1.2 14 | joblib==1.3.1 15 | kiwisolver==1.4.4 16 | llvmlite==0.40.1 17 | MarkupSafe==2.1.3 18 | matplotlib==3.7.2 19 | networkx 20 | nilearn==0.10.1 21 | nni==2.10.1 22 | node2vec~=0.4.6 23 | numba 24 | numpy 25 | pandas==2.0.3 26 | Pillow==10.0.0 27 | pyg-lib -f https://data.pyg.org/whl/torch-2.0.1+cu118.html 28 | pymanopt==2.1.1 29 | pyparsing==3.0.0 30 | python-dateutil==2.8.2 31 | python-louvain==0.16 32 | pytz==2023.3 33 | rdflib==6.3.2 34 | requests==2.31.0 35 | scikit-learn==1.3.0 36 | scipy==1.11.1 37 | six==1.16.0 38 | threadpoolctl==3.2.0 39 | torch==2.0.1 #--index-url https://download.pytorch.org/whl/cu118 40 | 41 | #PyG 42 | -f https://data.pyg.org/whl/torch-2.0.1+cu118.html 43 | torch_geometric==2.3.1 44 | torch-cluster #--index-url https://pytorch-geometric.com/whl/torch-{2.0.1}+{cu118}.html 45 | torch-scatter #--index-url https://pytorch-geometric.com/whl/torch-{2.0.1}+{cu118}.html 46 | torch-sparse #--index-url https://pytorch-geometric.com/whl/torch-{2.0.1}+{cu118}.html 47 | torch-spline-conv #--index-url 
class SimpleParam:
    """Resolve a parameter dictionary from defaults, NNI, or a config file.

    Calling an instance returns the defaults, optionally overlaid with
    values from the NNI tuner (``from_="nni"``) or from a ``.json`` /
    ``.yaml`` / ``.yml`` file (``from_=<path>``).
    """

    @staticmethod
    def _preprocess_nni(params: dict):
        """Strip the ``<prefix>/`` part of each key, keeping the 2nd segment.

        Keys without a ``/`` are kept unchanged (they used to raise
        ``IndexError``).
        """
        processed = {}
        for key, value in params.items():
            parts = key.split("/")
            processed[parts[1] if len(parts) > 1 else key] = value
        return processed

    @staticmethod
    def _parse_yaml(path: str):
        """Load and return the YAML document stored at ``path``."""
        with open(path, encoding="utf-8") as handle:
            content = handle.read()
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load trusted files, or switch to yaml.safe_load if the
        # configs are plain data.
        return yaml.load(content, Loader=yaml.Loader)

    @staticmethod
    def _parse_json(path: str):
        """Load and return the JSON document stored at ``path``."""
        with open(path, encoding="utf-8") as handle:
            return json.loads(handle.read())

    def __init__(self, default: Optional[dict] = None):
        self.default = default if default is not None else dict()

    def __call__(self, from_: Optional[str] = "None", *args, **kwargs):
        """Return the parameter dict for the requested source.

        Args:
            from_: ``"None"`` or ``None`` for the defaults only, ``"nni"``
                for the next NNI trial parameters, otherwise a path ending
                in ``.json``, ``.yaml`` or ``.yml``.
            **kwargs: ``preprocess_nni=True`` strips key prefixes from a
                loaded file (see ``_preprocess_nni``).

        Returns:
            The defaults dict itself when no source is given, otherwise a
            new dict with the source's values overriding the defaults.

        Raises:
            NotImplementedError: for a path with an unsupported extension.
        """
        # ``None`` is accepted alongside the historical "None" sentinel; it
        # used to fail with AttributeError on ``.endswith``.
        if from_ is None or from_ == "None":
            return self.default
        if from_ == "nni":
            return {**self.default, **nni.get_next_parameter()}

        if from_.endswith(".json"):
            loaded = self._parse_json(from_)
        elif from_.endswith((".yaml", ".yml")):
            loaded = self._parse_yaml(from_)
        else:
            raise NotImplementedError(f"Unsupported parameter source: {from_!r}")

        if kwargs.get("preprocess_nni"):
            loaded = self._preprocess_nni(loaded)

        return {**self.default, **loaded}
def load_data_private(root_dir, num_classes):
    """Load the private connectivity dataset for a 2- or 4-class task.

    Reads ``private_binary.npy`` (``num_classes == 2``) or
    ``private_multi.npy`` (``num_classes == 4``) from ``root_dir`` and
    extracts its ``"corr"``, ``"labels"`` and ``"p_ids"`` fields.

    Args:
        root_dir: directory containing the ``.npy`` files.
        num_classes: 2 or 4.

    Returns:
        Tuple ``(final_pearson, labels, p_IDs, s_IDs)``: the stacked float32
        ``"corr"`` entries, the int32 labels, the label-encoded ``p_ids``,
        and one running index (0, 1, 2, ... as floats) per entry.

    Raises:
        ValueError: if ``num_classes`` is neither 2 nor 4.
    """
    if num_classes == 2:
        file_name = "/private_binary.npy"
    elif num_classes == 4:
        file_name = "/private_multi.npy"
    else:
        # The message is now fully interpolated; previously the half that
        # contains {num_classes} was a plain string and printed the braces.
        raise ValueError(
            "Invalid number of classes, "
            f"expected n_classes=2 or n_classes=4 but got {num_classes}."
        )
    # NOTE(review): allow_pickle=True executes pickled payloads; only load
    # files from a trusted source.
    data = np.load(root_dir + file_name, allow_pickle=True)

    final_pearson = np.stack(
        [x.astype(np.float32) for x in data["corr"]], axis=0
    )  # reshape to (n_subjects, num_nodes, num_nodes)
    labels = data["labels"].astype(np.int32)
    p_ids = data["p_ids"]
    encoder = preprocessing.LabelEncoder()
    p_IDs = encoder.fit_transform(p_ids)
    # endpoint=False yields 0, 1, ..., len(p_ids) - 1
    s_IDs = np.linspace(0, len(p_ids), len(p_ids), endpoint=False)
    return final_pearson, labels, p_IDs, s_IDs
def mixup(batch_data):
    """Apply mixup to the node features and edge attributes of a batch.

    ``x`` is reshaped to ``(num_graphs, -1, num_features)`` and ``edge_attr``
    to ``(num_graphs, nodes_per_graph, -1)``, so every graph in the batch
    must contribute the same number of nodes and edge attributes. Both are
    mixed by ``mixup_data`` and written back to ``batch_data`` in their
    original flat layout.

    Args:
        batch_data: batch exposing ``x``, ``edge_index``, ``edge_attr``,
            ``y`` and ``batch``; modified in place.

    Returns:
        Tuple ``(batch_data, y_a, y_b, lam)`` where ``y_a``/``y_b`` are the
        two label sets and ``lam`` the mixing coefficient from
        ``mixup_data``.
    """
    x, edge_attr, y, batch = (
        batch_data.x,
        batch_data.edge_attr,
        batch_data.y,
        batch_data.batch,
    )
    num_graphs = int(torch.max(batch)) + 1
    x = x.reshape((num_graphs, -1, x.shape[-1]))
    edge_attr = edge_attr.reshape((num_graphs, x.shape[1], -1))

    # Run the permutation on the batch's own device instead of relying on
    # mixup_data's "cuda" default, which fails for CPU batches.
    mixed_x, mixed_edge_attr, y_a, y_b, lam = mixup_data(
        x, edge_attr, y, device=x.device
    )
    batch_data.x = mixed_x.reshape((-1, x.shape[-1]))
    # Write back the *mixed* edge attributes; the unmixed tensor used to be
    # stored here, so only the node features were actually mixed.
    batch_data.edge_attr = mixed_edge_attr.reshape((-1))

    return batch_data, y_a, y_b, lam
def _save_line_plot(curves, ylabel, path):
    """Plot ``(values, color, label)`` curves against epochs and save to ``path``."""
    fig = plt.figure(figsize=(10, 7))
    try:
        for values, color, label in curves:
            plt.plot(values, color=color, linestyle="-", label=label)
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(path)
    finally:
        # Release the figure even if saving fails; figures that are never
        # closed accumulate across calls.
        plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss):
    """
    Function to save the loss and accuracy plots to disk.

    Writes ``results/plots/accuracy_<timestamp>.png`` and
    ``results/plots/loss_<timestamp>.png``; both share one timestamp taken
    at call time. The ``results/plots`` directory must already exist.

    Args:
        train_acc: per-epoch training accuracy values.
        valid_acc: per-epoch validation accuracy values.
        train_loss: per-epoch training loss values.
        valid_loss: per-epoch validation loss values.
    """
    # NOTE(review): str(datetime.now()) contains spaces and colons, which
    # are not valid in file names on every platform.
    curr_dt = str(datetime.now())

    # accuracy plots
    _save_line_plot(
        [
            (train_acc, "green", "train accuracy"),
            (valid_acc, "blue", "validation accuracy"),
        ],
        "Accuracy",
        f"results/plots/accuracy_{curr_dt}.png",
    )

    # loss plots
    _save_line_plot(
        [
            (train_loss, "orange", "train loss"),
            (valid_loss, "red", "validation loss"),
        ],
        "Loss",
        f"results/plots/loss_{curr_dt}.png",
    )
# for degree_bin node features
def binning(a, n_bins=10):
    """One-hot encode the histogram bin index of every entry of ``a``.

    Bin edges come from a single ``np.histogram`` over all of ``a``; each
    entry is mapped to its bin index with ``np.digitize`` and the indices
    are one-hot encoded.

    Args:
        a: array-like whose first two dimensions are read as
            ``(n_graphs, n_nodes)``.
        n_bins: number of histogram bins.

    Returns:
        float32 array of shape ``(n_graphs, n_nodes, n_categories)``, with
        one category per distinct bin index the encoder is fitted on.
    """
    n_graphs, n_nodes = a.shape[0], a.shape[1]
    bin_edges = np.histogram(a, n_bins)[1]
    # A single column of bin indices: one sample per entry of ``a``.
    bin_ids = np.digitize(a, bin_edges).reshape(-1, 1)
    encoder = OneHotEncoder()
    one_hot = encoder.fit_transform(bin_ids).toarray()
    one_hot = one_hot.reshape(n_graphs, n_nodes, -1)
    return one_hot.astype(np.float32)
l,2,8,0 17 | TP r,9,9,0 18 | TP l,9,9,0 19 | aSTG r,9,9,0 20 | aSTG l,9,9,0 21 | pSTG r,9,9,0 22 | pSTG l,9,9,0 23 | aMTG r,1,10,0 24 | aMTG l,1,10,0 25 | pMTG r,1,11,0 26 | pMTG l,1,11,0 27 | toMTG r,7,7,0 28 | toMTG l,7,6,0 29 | aITG r,3,5,0 30 | aITG l,3,5,0 31 | pITG r,3,5,0 32 | pITG l,3,5,0 33 | toITG r,5,12,0 34 | toITG l,5,12,0 35 | PostCG r,2,13,0 36 | PostCG l,2,13,0 37 | SPL r,5,12,0 38 | SPL l,5,12,0 39 | aSMG r,4,14,0 40 | aSMG l,4,14,0 41 | pSMG r,7,7,0 42 | pSMG l,7,6,0 43 | AG r,1,11,0 44 | AG l,1,11,0 45 | sLOC r,3,15,0 46 | sLOC l,3,15,0 47 | iLOC r,3,16,0 48 | iLOC l,3,16,0 49 | ICC r,3,16,0 50 | ICC l,3,16,0 51 | MedFC,1,10,0 52 | SMA r,2,4,0 53 | SMA l,2,4,0 54 | SubCalC,1,10,0 55 | PaCiG r,9,17,0 56 | PaCiG l,9,17,0 57 | AC,9,18,0 58 | PC,1,19,0 59 | Precuneous,1,19,0 60 | Cuneal r,3,16,0 61 | Cuneal l,3,16,0 62 | FOrb r,3,5,0 63 | FOrb l,3,5,0 64 | aPaHC r,9,20,0 65 | aPaHC l,9,20,0 66 | pPaHC r,9,21,0 67 | pPaHC l,9,21,0 68 | LG r,3,16,0 69 | LG l,3,16,0 70 | aTFusC r,9,22,0 71 | aTFusC l,9,22,0 72 | pTFusC r,9,22,0 73 | pTFusC l,9,22,0 74 | TOFusC r,3,16,0 75 | TOFusC l,3,16,0 76 | OFusG r,3,16,0 77 | OFusG l,3,16,0 78 | FO r,5,18,0 79 | FO l,5,18,0 80 | CO r,2,4,0 81 | CO l,2,4,0 82 | PO r,2,4,0 83 | PO l,2,4,0 84 | PP r,2,4,0 85 | PP l,2,4,0 86 | HG r,2,4,0 87 | HG l,2,4,0 88 | PT r,2,4,0 89 | PT l,2,4,0 90 | SCC r,3,16,0 91 | SCC l,3,16,0 92 | OP r,3,16,0 93 | OP l,3,16,0 94 | Thalamus r,9,23,0 95 | Thalamus l,9,23,0 96 | Caudate r,8,23,0 97 | Caudate l,8,23,0 98 | Putamen r,8,18,0 99 | Putamen l,8,18,0 100 | Pallidum r,1,18,0 101 | Pallidum l,1,18,0 102 | Hippocampus r,9,22,0 103 | Hippocampus l,9,22,0 104 | Amygdala r,9,22,0 105 | Amygdala l,9,22,0 106 | Accumbens r,8,23,0 107 | Accumbens l,8,23,0 108 | Brain-Stem,8,24,0 109 | Cereb1 l,8,25,0 110 | Cereb1 r,8,25,0 111 | Cereb2 l,8,25,0 112 | Cereb2 r,8,25,0 113 | Cereb3 l,8,24,0 114 | Cereb3 r,8,24,0 115 | Cereb45 l,8,24,0 116 | Cereb45 r,8,24,0 117 | Cereb6 l,8,24,0 118 | Cereb6 
def get_top_k(mask, k=10):
    """
    Threshold explanation edge masks to top-k attention weights/edges

    Args:
        mask: array-like of scores; converted with ``torch.Tensor``.
        k: number of largest entries to keep. If ``mask`` has fewer than
            ``k`` entries, all of them are kept.

    Returns:
        Tensor with the shape of ``mask`` holding 1.0 at the positions of
        the ``k`` largest entries and 0.0 elsewhere.
    """
    mask = torch.Tensor(mask)
    flat = mask.flatten()
    # torch.topk raises when k exceeds the number of elements; a small mask
    # simply keeps everything it has.
    k = min(k, flat.numel())
    _, index = torch.topk(flat, k=k)

    out = torch.zeros_like(flat)
    out[index] = 1.0
    return out.view(mask.size())
datasets = {"0": [], "1": [], "2": [], "3": []} 31 | for data in dataloader: 32 | if str(data.y.detach().cpu().item()) in datasets: 33 | datasets[str(data.y.detach().cpu().item())].append(data) 34 | 35 | # Train explainer mask 36 | explainer_args = { 37 | "mode": "binary_classification" 38 | if args.num_classes == 2 39 | else "multiclass_classification" 40 | } 41 | explainer = Explainer( 42 | model=model, 43 | algorithm=AttentionExplainer(), 44 | explanation_type="model", 45 | edge_mask_type="object", 46 | model_config=dict( 47 | mode=explainer_args["mode"], 48 | task_level="graph", 49 | return_type="probs", 50 | ), 51 | ) 52 | 53 | explanation_adjs = [] 54 | explanation_edges = [] 55 | 56 | for i in range(args.num_classes): 57 | adjs = [] 58 | edges_all = [] 59 | for data in list(datasets.values())[i]: 60 | data = data.to(args.device) 61 | exp = explainer( 62 | data, data.edge_index, edge_attr=data.edge_attr, batch=data.batch 63 | ) 64 | G, edges, _ = visualize_graph( 65 | exp.edge_index, 66 | edge_attr=exp.edge_attr, 67 | x=exp.x, 68 | y=exp.y, 69 | threshold_num=10, 70 | ) 71 | 72 | nx.write_graphml_lxml( 73 | G, 74 | f"outputs/explanations/graphs/xGWGAT_exp_graph_{int(exp.x.s_ID.item())}.graphml", 75 | ) 76 | subgraph = exp.get_explanation_subgraph() 77 | 78 | mask = subgraph.edge_mask 79 | edges_all.append(edges) 80 | adj = edge_index_to_adj_matrix( 81 | subgraph.edge_index, subgraph.edge_attr, args.num_nodes 82 | ) 83 | adjs.append(adj) 84 | 85 | avg_edges = np.array(edges_all).mean(axis=0) 86 | explanation_edges.append(avg_edges) 87 | 88 | avg_adjs = np.array(adjs).mean(axis=0) 89 | masked_adj = get_top_k(avg_adjs, k=10) 90 | explanation_adjs.append(masked_adj.numpy()) 91 | 92 | explanation_adjs = np.array(explanation_adjs, dtype="object") 93 | explanation_edges = np.array(explanation_edges, dtype="object") 94 | 95 | roi_xyz = pd.read_csv( 96 | "community_networks/roi_coords.csv", index_col=False, header=None, skiprows=1 97 | ) 98 | roi_xyz = 
roi_xyz.drop(roi_xyz.columns[0], axis=1).T 99 | roi_labels = pd.read_csv("community_networks/roi_names_mod.csv")["ROI"] 100 | roi_colors = pd.read_csv("community_networks/roi_names_mod.csv")["Color"].to_numpy() 101 | 102 | # Generate .node and .edge files for visualization in BrainNet Viewer 103 | to_brainnet(edges, roi_xyz, roi_labels, C=roi_colors, prefix="top10") 104 | for i in range(len(explanation_edges)): 105 | to_brainnet( 106 | explanation_edges[i], roi_xyz, roi_labels, C=roi_colors, prefix=f"top10_{i}" 107 | ) 108 | 109 | return 110 | -------------------------------------------------------------------------------- /src/utils/train_and_evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import nni 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from sklearn import metrics 8 | 9 | from src.utils.metrics import multiclass_roc_auc_score 10 | from src.utils.save_model import SaveBestModel 11 | 12 | # Create logger 13 | logger = logging.getLogger("__name__") 14 | level = logging.INFO 15 | logger.setLevel(level) 16 | ch = logging.StreamHandler() 17 | ch.setLevel(level) 18 | logger.addHandler(ch) 19 | 20 | 21 | def train_eval( 22 | model, optimizer, scheduler, class_weights, args, train_loader, test_loader=None 23 | ): 24 | """ 25 | Train model 26 | """ 27 | model.train() 28 | save_best_model = SaveBestModel() # initialize SaveBestModel class 29 | criterion = nn.NLLLoss(weight=class_weights) 30 | 31 | train_preds, train_labels, train_aucs, train_accs = [], [], [], [] 32 | total_correct = 0 33 | total_samples = 0 34 | 35 | for i in range(args.epochs): 36 | running_loss = 0 # running loss for logging 37 | avg_train_losses = [] # average training loss per epoch 38 | 39 | for data in train_loader: 40 | data = data.to(args.device) 41 | optimizer.zero_grad() 42 | out = model(data) 43 | pred = out.max(dim=1)[1] # Get predicted labels 44 | 
@torch.no_grad()
def test(model, loader, args, test_loader=None):
    """Evaluate ``model`` and return accuracy and ROC AUC.

    Args:
        model: Network called as ``model(data)``; the arg-max over dim 1 of its
            output is taken as the predicted class.
        loader: Iterable of batches to evaluate.
        args: Namespace providing ``device`` and ``num_classes``.
        test_loader: Legacy switch. When given, ``loader`` is ignored and the
            metrics are computed on ``test_loader`` instead.

    Returns:
        ``(accuracy, auc, preds, labels)`` when ``test_loader`` is ``None``.
        ``(auc, accuracy)`` -- note the swapped order and shorter tuple -- when
        ``test_loader`` is given (kept for backward compatibility).

    Note:
        The model is left in eval mode; callers that keep training must call
        ``model.train()`` again.
    """
    if test_loader is not None:
        test_acc, test_auc, _, _ = test(model, test_loader, args)
        return test_auc, test_acc

    model.eval()

    batch_preds, batch_labels = [], []
    for data in loader:
        data = data.to(args.device)
        out = model(data)
        batch_preds.append(out.argmax(dim=1).detach().cpu().numpy().ravel())
        batch_labels.append(data.y.detach().cpu().numpy().ravel())

    # Batches may differ in size (the last one usually does), so they must be
    # concatenated; np.array(list_of_batches) would build a ragged array.
    if batch_preds:
        preds = np.concatenate(batch_preds)
        labels = np.concatenate(batch_labels)
    else:
        preds = np.array([], dtype=np.int64)
        labels = np.array([], dtype=np.int64)

    try:
        if args.num_classes > 2:
            t_auc = multiclass_roc_auc_score(labels, preds)
        else:
            # NOTE(review): AUC is computed from hard labels, not class
            # probabilities -- confirm this is intended.
            t_auc = metrics.roc_auc_score(labels, preds, average="weighted")
    except ValueError as err:
        # e.g. only one class present in ``labels``; fall back to chance level.
        logging.warning("ROC AUC undefined, using 0.5: %s", err)
        t_auc = 0.5

    t_acc = np.mean(preds == labels)
    return t_acc, t_auc, preds, labels
def dense_to_ind_val(adj):
    """Convert a dense adjacency tensor into index/value (COO-style) form.

    Every entry that is not NaN becomes an edge; zeros are kept.

    Args:
        adj: Square tensor of shape ``(N, N)`` or a batch ``(B, N, N)``.

    Returns:
        ``(index, values)`` where ``index`` has shape ``(adj.dim(), E)`` and
        ``values`` has shape ``(E,)``.

    Raises:
        ValueError: If ``adj`` is not 2-D/3-D or its last two dims differ.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if not 2 <= adj.dim() <= 3:
        raise ValueError(f"adj must be 2-D or 3-D, got {adj.dim()} dimensions")
    if adj.size(-1) != adj.size(-2):
        raise ValueError(f"adj must be square, got shape {tuple(adj.shape)}")

    index = (~torch.isnan(adj)).nonzero(as_tuple=True)
    edge_attr = adj[index]

    return torch.stack(index, dim=0), edge_attr
class BrainDataset(InMemoryDataset):
    """In-memory dataset of brain connectivity graphs ("PPMI" or "PRIVATE").

    Each sample is one dense connectivity matrix turned into a ``Data`` object
    whose node features are the matrix rows and whose edges are all non-NaN
    entries. The processed tensors are cached under ``<root>/processed``.

    Args:
        root: Directory holding the raw data; also the parent of the cache.
        name: Dataset name, case-insensitive; ``"PPMI"`` or ``"PRIVATE"``.
        num_classes: Number of target classes (part of the cache file name and
            forwarded to the private loader).
        transform: Optional per-access transform.
        pre_transform: Optional transform applied once before caching; its
            ``str()`` becomes part of the cache file name.
        view: Index of the connectivity view to load.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """

    def __init__(
        self,
        root,
        name,
        num_classes,
        transform=None,
        pre_transform: BaseTransform = None,
        view=0,
    ):
        self.view: int = view
        self.name = name.upper()
        self.n_classes = num_classes
        self.filename_postfix = (
            str(pre_transform) if pre_transform is not None else None
        )
        if self.name not in ("PPMI", "PRIVATE"):
            raise ValueError(
                f"Unknown dataset {self.name!r}; expected 'PPMI' or 'PRIVATE'"
            )
        super(BrainDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices, self.num_nodes = torch.load(self.processed_paths[0])
        logging.info("Loaded dataset: {}".format(self.name))

    @property
    def raw_dir(self):
        # Raw files live directly in ``root`` (no ``raw/`` sub-directory).
        return self.root

    @property
    def processed_dir(self):
        return osp.join(self.root, "processed")

    @property
    def raw_file_names(self):
        return f"{self.name}.mat"

    @property
    def processed_file_names(self):
        # The cache name encodes everything that changes the processed output.
        name = f"{self.name}_{self.view}_{self.n_classes}"
        if self.filename_postfix is not None:
            name += f"_{self.filename_postfix}"
        return f"{name}.pt"

    def _download(self):
        # The private dataset has no single raw file, so it never downloads.
        if files_exist(self.raw_paths) or self.name in ["PRIVATE"]:  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def download(self):
        raise NotImplementedError

    def process(self):
        """Build the ``Data`` list from raw files and write the cache file."""
        # Participant/session IDs only exist for the private dataset; the
        # ``.mat`` datasets leave them as ``None`` and omit the attributes.
        p_IDs = s_IDs = None

        if self.name == "PRIVATE":
            adj, y, p_IDs, s_IDs = load_data_private(self.raw_dir, self.n_classes)
            y = torch.LongTensor(y)
            adj = torch.Tensor(np.array(adj))
            p_IDs = torch.Tensor(p_IDs)
            s_IDs = torch.Tensor(s_IDs)
            num_graphs = adj.shape[0]
            num_nodes = adj.shape[1]
        else:
            m = loadmat(osp.join(self.raw_dir, self.raw_file_names))
            if self.name == "PPMI":
                if self.view > 2 or self.view < 0:
                    raise ValueError(f"{self.name} only has 3 views")
                raw_data = m["X"]
                num_graphs = raw_data.shape[0]
                num_nodes = raw_data[0][0].shape[0]
                a = np.zeros((num_graphs, num_nodes, num_nodes))
                for i, sample in enumerate(raw_data):
                    a[i, :, :] = sample[0][:, :, self.view]
                adj = torch.Tensor(a)
            else:
                # Not reachable with the names accepted by __init__; kept for
                # other ``.mat`` datasets that store "fmri"/"dti" arrays.
                key = "fmri" if self.view == 1 else "dti"
                adj = torch.Tensor(m[key]).transpose(0, 2)
                num_graphs = adj.shape[0]
                num_nodes = adj.shape[1]

            y = torch.Tensor(m["label"]).long().flatten()
            y[y == -1] = 0  # map {-1, 1} labels onto {0, 1}

        data_list = []
        for i in range(num_graphs):
            edge_index, edge_attr = dense_to_ind_val(adj[i])
            extras = {}
            if p_IDs is not None:
                extras["p_ID"] = p_IDs[i]
            if s_IDs is not None:
                extras["s_ID"] = s_IDs[i]
            data_list.append(
                Data(
                    x=adj[i],
                    num_nodes=num_nodes,
                    y=y[i],
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    **extras,
                )
            )

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices, num_nodes), self.processed_paths[0])

    def _process(self):
        print("Processing...", file=sys.stderr)

        if files_exist(self.processed_paths):  # pragma: no cover
            print("File exists...Done!", file=sys.stderr)
            return

        makedirs(self.processed_dir)
        self.process()

        print("Done!", file=sys.stderr)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}{self.name}()"
def get_class_id(idx, class_labels):
    """Look up the class label of one sample.

    Parameters:
        idx: Index (or key) of the sample.
        class_labels: Indexable container holding one class label per sample.

    Returns:
        The entry of ``class_labels`` stored at ``idx``.
    """
    class_id = class_labels[idx]
    return class_id
def select_samples(
    train_idx, n_splits, k_list, data_dict, score_dict, y=None, shuffle=True, rs=None
):
    """Select the most frequently top-ranked training samples via K-fold voting.

    For every fold a linear regression is fitted that maps the pairwise
    ``data_dict`` value of two training samples to their pairwise
    ``score_dict`` value. For each held-out sample the regression predicts a
    score against every training sample of the fold, and the ``k`` training
    samples with the lowest prediction receive one vote. The ``k`` samples with
    the most votes are returned for each ``k``.

    Args:
        train_idx: Sample identifiers (any sequence; converted to an array).
        n_splits: Number of K-fold splits.
        k_list: Iterable of ``k`` values to produce selections for.
        data_dict: Pairwise lookup indexed as ``data_dict[a, b]``.
        score_dict: Pairwise lookup indexed as ``score_dict[a, b]``.
        y: Unused; kept for interface compatibility.
        shuffle: Forwarded to ``KFold``.
        rs: Random state forwarded to ``KFold``.

    Returns:
        Dict mapping each ``k`` to an array of at most ``k`` sample identifiers.
    """
    # Fancy indexing below needs an array; a plain list would raise TypeError.
    train_idx = np.asarray(train_idx)
    freq_dict = {k: defaultdict(int) for k in k_list}
    splitter = KFold(n_splits=n_splits, shuffle=shuffle, random_state=rs)

    for train_ids, holdout_ids in splitter.split(train_idx):
        fold_train = train_idx[train_ids]

        # One regression sample per unordered pair of training samples.
        residuals, fiq_errors = [], []
        for pos, first in enumerate(fold_train):
            for second in fold_train[pos + 1 :]:
                residuals.append(data_dict[first, second])
                fiq_errors.append(score_dict[first, second])

        # Scalar distances need an explicit feature axis for scikit-learn.
        if np.ndim(residuals[0]) == 0:
            residuals = np.array(residuals).reshape(-1, 1)

        selection_net = LinearRegression()
        selection_net.fit(residuals, fiq_errors)

        for held_out in train_idx[holdout_ids]:
            r_tst = np.stack([data_dict[held_out, member] for member in fold_train])
            if np.ndim(r_tst[0]) == 0:
                r_tst = r_tst.reshape(-1, 1)
            error_pred = selection_net.predict(r_tst).ravel()

            # The ranking does not depend on k, so sort once per held-out sample.
            ranking = np.argsort(error_pred)
            for k in k_list:
                for sample_id in fold_train[ranking[:k]]:
                    freq_dict[k][sample_id] += 1

    important_samples = {k: [] for k in k_list}
    for k, votes in freq_dict.items():
        ids = np.array(list(votes.keys()))
        freqs = np.array(list(votes.values()))
        important_samples[k] = ids[np.argsort(freqs)[::-1][:k]]
    return important_samples
def select_samples_per_class(
    train_idx,
    n_splits,
    k_list,
    data_dict,
    score_dict,
    class_labels,
    shuffle=True,
    rs=None,
):
    """Run the K-fold sample-selection vote separately inside every class.

    Samples are grouped by ``get_class_id(idx, class_labels)``. Within each
    class a linear regression maps pairwise ``data_dict`` values to pairwise
    ``score_dict`` values; for each held-out sample the ``k`` training samples
    with the lowest predicted score get one vote, and the ``k`` most voted
    samples are kept.

    Args:
        train_idx: Iterable of sample identifiers.
        n_splits: Requested number of folds; capped per class at the class
            size.
        k_list: Iterable of ``k`` values to produce selections for.
        data_dict: Pairwise lookup indexed as ``data_dict[a, b]``.
        score_dict: Pairwise lookup indexed as ``score_dict[a, b]``.
        class_labels: Container giving the class label of each sample.
        shuffle: Forwarded to ``KFold``.
        rs: Random state forwarded to ``KFold``.

    Returns:
        ``defaultdict`` mapping class id -> {k: array of selected sample ids}.

    Raises:
        ValueError: Propagated from ``KFold`` when a class ends up with fewer
            than two folds (i.e. a class with a single sample).
    """
    samples_by_class = defaultdict(list)
    for idx in train_idx:
        samples_by_class[get_class_id(idx, class_labels)].append(idx)

    important_samples_per_class = defaultdict(lambda: {k: [] for k in k_list})

    for class_id, class_samples in samples_by_class.items():
        # Cap the fold count for THIS class only; the requested ``n_splits``
        # must stay intact for the classes processed afterwards.
        class_splits = min(n_splits, len(class_samples))

        freq_dict = {k: defaultdict(int) for k in k_list}
        splitter = KFold(n_splits=class_splits, shuffle=shuffle, random_state=rs)

        for train_ids, holdout_ids in splitter.split(class_samples):
            residuals = []
            score_errors = []

            if len(train_ids) == 1:
                # Only one training sample in the fold, so there is no pair to
                # regress on; the sample is paired with the next identifier.
                # NOTE(review): ``sample + 1`` may not belong to this class or
                # may not exist in the lookups -- confirm this fallback.
                only = class_samples[train_ids[0]]
                residuals.append(data_dict[only, only + 1])
                score_errors.append(score_dict[only, only + 1])
            else:
                for pos, train_id_i in enumerate(train_ids):
                    for train_id_j in train_ids[pos + 1 :]:
                        pair = (class_samples[train_id_i], class_samples[train_id_j])
                        residuals.append(data_dict[pair])
                        score_errors.append(score_dict[pair])

            # Scalar distances need an explicit feature axis for scikit-learn.
            if np.ndim(residuals[0]) == 0:
                residuals = np.array(residuals).reshape(-1, 1)

            selection_net = LinearRegression()
            selection_net.fit(residuals, score_errors)

            for holdout_id in holdout_ids:
                held_out = class_samples[holdout_id]
                r_tst = np.stack(
                    [data_dict[held_out, class_samples[tid]] for tid in train_ids]
                )
                if np.ndim(r_tst[0]) == 0:
                    r_tst = r_tst.reshape(-1, 1)
                error_pred = selection_net.predict(r_tst).ravel()

                # The ranking does not depend on k, so sort once.
                ranking = np.argsort(error_pred)
                for k in k_list:
                    for pos in ranking[:k]:
                        freq_dict[k][class_samples[train_ids[pos]]] += 1

        for k, votes in freq_dict.items():
            ids = np.array(list(votes.keys()))
            freqs = np.array(list(votes.values()))
            # Descending vote count, keep the top k.
            important_samples_per_class[class_id][k] = ids[
                np.argsort(freqs)[::-1][:k]
            ]

    return important_samples_per_class
def visualize_graph(
    edge_index,
    edge_attr: Optional[Tensor],
    node_atts=None,
    x=None,
    y: Optional[torch.FloatTensor] = None,
    threshold_num=None,
):
    # Adapted from https://github.com/HennyJie/IBGNN/
    r"""Build a networkx graph and dense edge matrices from an edge list.

    Args:
        edge_index (LongTensor): Edge indices of shape ``(2, E)``.
        edge_attr (Tensor, optional): One weight per edge; defaults to ones.
        node_atts: Unused; kept for interface compatibility.
        x (Tensor, optional): Node features stored on the intermediate
            :class:`Data` object.
        y (Tensor, optional): Per-node values stored as the ``y`` node
            attribute after division by their maximum. Defaults to zeros.
        threshold_num (int, optional): If given, the dense edge matrix is
            passed through :func:`denoise_graph` with this value.

    Returns:
        ``(G, edges, unfiltered_edges)``: the networkx graph with node
        attribute ``y`` and edge attribute ``att``, the (optionally
        thresholded) dense edge matrix, and a copy taken before thresholding.

    Raises:
        ValueError: If ``edge_attr`` does not have one entry per edge.
    """
    if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):
        raise ValueError(
            f"edge_attr has {edge_attr.size(0)} entries for "
            f"{edge_index.size(1)} edges"
        )

    num_nodes = edge_index.max().item() + 1

    if edge_attr is None:
        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

    if y is None:
        y = torch.zeros(num_nodes, device=edge_index.device)
    else:
        y = y.cpu()
        y_max = y.max().item()
        # ``y`` now lives on the CPU, so the index tensor must as well.
        y = y[torch.arange(num_nodes)].to(torch.float)
        if y_max != 0:  # avoid 0/0 -> NaN when every value is zero
            y = y / y_max

    data = Data(edge_index=edge_index, att=edge_attr, x=x, y=y, num_nodes=y.size(0)).to(
        "cpu"
    )
    G = to_networkx(data, node_attrs=["y"], edge_attrs=["att"])

    edges = edge_index_to_adj_matrix(edge_index, edge_attr, y.shape[0])

    # Copy first: denoise_graph modifies its input in place.
    unfiltered_edges = edges.copy()
    if threshold_num is not None:
        edges = denoise_graph(edges, 0, threshold_num=threshold_num)

    return G, edges, unfiltered_edges
def denoise_graph(
    adj,
    node_idx,
    feat=None,
    label=None,
    threshold=None,
    threshold_num=None,
    max_component=True,
):
    """Zero out weak entries of an adjacency matrix, in place.

    Args:
        adj: Square adjacency matrix; modified in place and returned.
        node_idx: Unused; kept for interface compatibility.
        feat: Unused; kept for interface compatibility.
        label: Unused; kept for interface compatibility.
        threshold: Entries that are not ``>= threshold`` are set to 0. Ignored
            when ``threshold_num`` is given.
        threshold_num: Number of undirected edges to keep. The matrix is
            assumed symmetric, so the ``2 * threshold_num`` largest positive
            entries survive (ties at the cut-off are all kept).
        max_component: Unused; kept for interface compatibility.

    Returns:
        The same ``adj`` object. With neither ``threshold`` nor
        ``threshold_num``, entries that are not ``> 1e-6`` are set to 0.
    """
    if threshold_num is not None:
        # Symmetric graphs list every edge twice in the adjacency matrix.
        positive = adj[adj > 0]
        keep = min(len(positive), threshold_num * 2)
        # With nothing to keep, an infinite threshold clears every entry
        # instead of indexing into an empty array.
        threshold = np.sort(positive)[-keep] if keep > 0 else np.inf

    if threshold is not None:
        weak = ~(adj >= threshold)
    else:
        weak = ~(adj > 1e-6)
    # The negated comparison also catches NaN entries.
    adj[weak] = 0
    return adj
def to_brainnet(
    edges,
    roi_xyz,
    roi_labels,
    C=None,
    S=None,
    path="outputs/explanations/bnv",
    prefix="bnv",
):
    """Export data to plaintext file(s) for use with BrainNet Viewer
    [1]. For details regarding .node and .edge file construction, the
    user is directed to the BrainNet Viewer User Manual.
    This code was quality tested using BrainNet version 1.61 released on
    2017-10-31 with MATLAB 9.3.0.713579 (R2017b).
    Parameters:
    -----------
    edges : numpy array
        N x N matrix containing edge weights.
    roi_xyz : pandas dataframe
        N x 3 dataframe containing the (x, y, z) MNI coordinates of each
        brain ROI.
    roi_labels : pandas series
        Names of each ROI as string.
    C : pandas series
        Node color value (defaults to same color). For modular color,
        use integers; for continuous data use floats.
    S : pandas series
        Node size value (defaults to same size).
    path : string
        Path to output directory (default is
        ``outputs/explanations/bnv``); created if it does not exist.
        Note: do not include trailing '/' at end.
    prefix : string
        Filename prefix for output files.
    Returns
    -------
    .node, .edge : files
        Plaintext output files for input to BrainNet.
    References
    ----------
    [1] Xia M, Wang J, He Y (2013) BrainNet Viewer: A Network
        Visualization Tool for Human Brain Connectomics. PLoS ONE 8:
        e68910.
    """
    import os  # local import: only needed to create the output directory

    N = len(roi_xyz)  # number of nodes

    if C is None:
        C = np.ones(N)

    if S is None:
        S = np.ones(N)

    # BrainNet does not recognize node labels with white space, replace
    # spaces with underscore
    labels = roi_labels.str.replace(" ", "_").to_list()

    # Build .node dataframe: coordinates, color, size, label (in that order)
    df = roi_xyz.copy().assign(C=C, S=S, labels=labels)

    # Writing fails if the target directory is missing, so create it first.
    os.makedirs(path, exist_ok=True)

    # Output .node file
    df.to_csv(f"{path}/{prefix}.node", sep="\t", header=False, index=False)
    print(f"Saved {path}/{prefix}.node")

    # Output .edge file
    np.savetxt(f"{path}/{prefix}.edge", edges, delimiter="\t")
    print(f"Saved {path}/{prefix}.edge")
class MPGCNConv(GCNConv):
    """GCN convolution with configurable, edge-aware message construction.

    ``gcn_mp_type`` selects how a message is built from the transformed
    neighbour features ``x_j`` (and optionally ``x_i`` / the edge weight):

    * ``"weighted_sum"``: ``edge_weight * x_j``
    * ``"bin_concat"``: ``x_j`` concatenated with a learned embedding of the
      edge-weight bucket
    * ``"edge_weight_concat"``: ``x_j`` concatenated with the tiled edge weight
    * ``"edge_node_concat"``: ``x_i``, ``x_j`` and the edge weight concatenated
    * ``"node_concat"``: ``x_i`` and ``x_j`` concatenated

    All concatenating variants are projected back to ``out_channels`` by
    ``edge_lin``.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        edge_emb_dim: Width of the edge embedding / tiled edge weight.
        gcn_mp_type: Message-passing variant (see above).
        bucket_sz: Bucket width used by ``"bin_concat"``; the bucket count is
            ``ceil(2.0 / bucket_sz)``.
        normalize: Stored on the layer as ``self.normalize``.
        bias: Whether the layer learns an additive bias.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        edge_emb_dim: int,
        gcn_mp_type: str,
        bucket_sz: float,
        normalize: bool = True,
        bias: bool = True,
    ):
        # ``bias`` is forwarded so that ``bias=False`` is honoured instead of
        # being silently ignored.
        super(MPGCNConv, self).__init__(
            in_channels=in_channels, out_channels=out_channels, aggr="add", bias=bias
        )

        self.edge_emb_dim = edge_emb_dim
        self.gcn_mp_type = gcn_mp_type
        self.bucket_sz = bucket_sz
        self.bucket_num = math.ceil(2.0 / self.bucket_sz)
        if gcn_mp_type == "bin_concat":
            self.edge2vec = nn.Embedding(self.bucket_num, edge_emb_dim)

        self.normalize = normalize
        self._cached_edge_index = None
        self._cached_adj_t = None

        # Width of the concatenated message that ``edge_lin`` projects back.
        input_dim = out_channels
        if gcn_mp_type == "bin_concat" or gcn_mp_type == "edge_weight_concat":
            input_dim = out_channels + edge_emb_dim
        elif gcn_mp_type == "edge_node_concat":
            input_dim = out_channels * 2 + 1
        elif gcn_mp_type == "node_concat":
            input_dim = out_channels * 2
        self.edge_lin = torch.nn.Linear(input_dim, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        # Also invoked from the parent constructor, before the attributes set
        # in ``__init__`` exist, so it only touches parent-owned state.
        if self.bias is not None:
            zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def message(self, x_i, x_j, edge_weight):
        # x_i / x_j: [E, out_channels] -- ``edge_lin`` is sized for
        # out_channels-wide inputs.
        if self.gcn_mp_type == "weighted_sum":
            # use edge_weight as multiplier
            msg = edge_weight.view(-1, 1) * x_j
        elif self.gcn_mp_type == "bin_concat":
            # concat xj and learned bin embedding
            bucket = torch.div(edge_weight + 1, self.bucket_sz, rounding_mode="trunc")
            # A weight of exactly 1 maps to index ``bucket_num``, one past the
            # embedding table; clamp into the valid range.
            bucket = bucket.clamp(0, self.bucket_num - 1).long()
            msg = torch.cat([x_j, self.edge2vec(bucket)], dim=1)
            msg = self.edge_lin(msg)
        elif self.gcn_mp_type == "edge_weight_concat":
            # concat xj and tiled edge attr
            msg = torch.cat(
                [x_j, edge_weight.view(-1, 1).repeat(1, self.edge_emb_dim)], dim=1
            )
            msg = self.edge_lin(msg)
        elif self.gcn_mp_type == "edge_node_concat":
            # concat xi, xj and edge_weight
            msg = torch.cat([x_i, x_j, edge_weight.view(-1, 1)], dim=1)
            msg = self.edge_lin(msg)
        elif self.gcn_mp_type == "node_concat":
            # concat xi and xj
            msg = torch.cat([x_i, x_j], dim=1)
            msg = self.edge_lin(msg)
        else:
            raise ValueError(f"Invalid message passing variant {self.gcn_mp_type}")
        return msg
class GCN(torch.nn.Module):
    """Graph-level classifier: stacked ``MPGCNConv`` stages, pooling, MLP head.

    Args:
        input_dim: Size of the input node features.
        args: Namespace providing ``pooling`` (``"concat"``, ``"sum"`` or
            ``"mean"``), ``num_nodes``, ``num_classes``, ``hidden_dim``,
            ``n_GNN_layers``, ``edge_emb_dim``, ``gcn_mp_type`` and
            ``bucket_sz``.

    Raises:
        ValueError: If ``args.pooling`` is not a supported mode.
    """

    def __init__(self, input_dim, args):
        super(GCN, self).__init__()
        # Imported here because the module only imports individual names from
        # ``torch_geometric.nn``; the bare ``torch_geometric`` name is not bound.
        from torch_geometric.nn import Sequential as PygSequential

        self.activation = torch.nn.ReLU()
        self.convs = torch.nn.ModuleList()
        self.pooling = args.pooling
        self.num_nodes = args.num_nodes
        self.num_classes = args.num_classes

        if self.pooling not in ("concat", "sum", "mean"):
            raise ValueError(f"Invalid pooling mode {self.pooling}")

        hidden_dim = args.hidden_dim
        num_layers = args.n_GNN_layers
        edge_emb_dim = args.edge_emb_dim
        gcn_mp_type = args.gcn_mp_type
        bucket_sz = args.bucket_sz
        stage_inputs = "x, edge_index, edge_attr"

        def mp_conv(in_dim):
            # Message-passing entry of a stage, in PyG ``Sequential`` form.
            return (
                MPGCNConv(
                    in_dim,
                    hidden_dim,
                    edge_emb_dim,
                    gcn_mp_type,
                    bucket_sz,
                    normalize=True,
                    bias=True,
                ),
                "x, edge_index, edge_attr -> x",
            )

        gcn_input_dim = input_dim
        for _ in range(num_layers - 1):
            self.convs.append(
                PygSequential(
                    stage_inputs,
                    [
                        mp_conv(gcn_input_dim),
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.LeakyReLU(negative_slope=0.2),
                        nn.BatchNorm1d(hidden_dim),
                    ],
                )
            )
            gcn_input_dim = hidden_dim

        # The final stage consumes ``gcn_input_dim`` so that a single-layer
        # network still receives the raw ``input_dim`` features.
        if self.pooling == "concat":
            node_dim = 8
            final_stage = PygSequential(
                stage_inputs,
                [
                    mp_conv(gcn_input_dim),
                    nn.Linear(hidden_dim, 64),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Linear(64, node_dim),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.BatchNorm1d(node_dim),
                ],
            )
            fcn_input_dim = node_dim * self.num_nodes
        else:  # "sum" or "mean"
            node_dim = 256
            final_stage = PygSequential(
                stage_inputs,
                [
                    mp_conv(gcn_input_dim),
                    # Project to ``node_dim`` so the batch norm and the MLP
                    # head match for any ``hidden_dim`` (identical layout when
                    # ``hidden_dim == 256``).
                    nn.Linear(hidden_dim, node_dim),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.BatchNorm1d(node_dim),
                ],
            )
            fcn_input_dim = node_dim

        self.convs.append(final_stage)

        self.fcn = nn.Sequential(
            nn.Linear(fcn_input_dim, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, 32),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(32, self.num_classes),
        )

    def forward(self, data, edge_index=None, edge_attr=None, batch=None):
        """Return one row of class scores per graph.

        ``edge_index``, ``edge_attr`` and ``batch`` default to the attributes
        of ``data``, so both ``model(data)`` and the explicit four-argument
        call work.

        NOTE(review): the head ends in a plain ``Linear`` layer (raw scores)
        while the training loop uses ``NLLLoss`` -- confirm whether a
        ``log_softmax`` is expected here.
        """
        z = data.x
        if edge_index is None:
            edge_index = data.edge_index
        if edge_attr is None:
            edge_attr = data.edge_attr
        if batch is None:
            batch = data.batch
        edge_attr = torch.abs(edge_attr)

        for conv in self.convs:
            # bz*nodes, hidden
            z = conv(z, edge_index, edge_attr)

        if self.pooling == "concat":
            # Flatten all node embeddings of a graph into one vector.
            z = z.reshape((z.shape[0] // self.num_nodes, -1))
        elif self.pooling == "sum":
            z = global_add_pool(z, batch)  # [N, F]
        elif self.pooling == "mean":
            z = global_mean_pool(z, batch)  # [N, F]

        out = self.fcn(z)
        return out
class FromSVTransform(BaseTransform):
    """Apply a single-view transform to every view of a multi-view sample.

    A view is any ``edge_index<postfix>`` / ``edge_attr<postfix>`` pair on the
    data object. The wrapped transform is run on each view and its ``x``,
    ``edge_index`` and ``edge_attr`` are written back under the same postfix.
    """

    def __init__(self, sv_transform):
        super(FromSVTransform, self).__init__()
        self.sv_transform = sv_transform

    def __call__(self, data):
        # ``keys`` is a property in older PyG releases and a method in newer
        # ones; support both.
        all_keys = data.keys() if callable(data.keys) else data.keys
        # Materialise the list first: new keys are added while looping.
        view_keys = [key for key in all_keys if key.startswith("edge_index")]
        for key in view_keys:
            postfix = key[len("edge_index") :]
            svdata = Data(
                edge_index=data[f"edge_index{postfix}"],
                edge_attr=data[f"edge_attr{postfix}"],
                num_nodes=data.num_nodes,
            )
            svdata_transformed = self.sv_transform(svdata)
            data[f"x{postfix}"] = svdata_transformed.x
            data[f"edge_index{postfix}"] = svdata_transformed.edge_index
            data[f"edge_attr{postfix}"] = svdata_transformed.edge_attr
        return data

    def __str__(self):
        # The wrapped transform's class name.
        return self.sv_transform.__class__.__name__
class Identity(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the identity matrix (one-hot node IDs).
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        data.x = torch.eye(data.num_nodes)
        return data


class Degree(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the weighted degree of each node, i.e. the
        row sums of the dense adjacency matrix, as a column vector.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        data.x = adj.sum(dim=1, keepdim=True).float()
        return data

    def __str__(self):
        return "Degree"


class LDPTransform(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the output of ``LDP`` computed on the graph
        built from the dense adjacency matrix.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        data.x = torch.Tensor(LDP(nx.from_numpy_array(adj.numpy()))).float()
        return data

    def __str__(self):
        return "LDP"


class DegreeBin(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the binned weighted node degrees.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        # Store the features on ``data`` and return it, like every other
        # transform; returning the bare tensor dropped the graph.
        # NOTE(review): the shape of ``binning``'s output is not visible here
        # -- confirm it is (num_nodes, features).
        data.x = torch.Tensor(binning(adj.sum(dim=1))).float()
        return data

    def __str__(self):
        return "Degree_Bin"


class Adj(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the rows of the dense adjacency matrix.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        data.x = adj
        return data

    def __str__(self):
        return "Adj"
class Eigenvector(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the eigenvectors of the dense adjacency
        matrix; row ``i`` of ``x`` is the ``i``-th eigenvector returned by
        ``numpy.linalg.eig``.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        # Eigenvalues are not needed; eigenvectors come back as columns.
        _, v = LA.eig(adj.numpy())
        v = v.transpose()
        data.x = torch.Tensor(v).float()
        return data


class EigenNorm(BaseTransform):
    def __call__(self, data: BrainData):
        """
        Sets the node features to the eigenvectors of the adjacency matrix
        after dividing it by its row sums; row ``i`` of ``x`` is the ``i``-th
        eigenvector returned by ``numpy.linalg.eig``.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        sum_of_rows = adj.sum(dim=1)
        # NOTE(review): broadcasting divides column ``j`` by the sum of row
        # ``j``; confirm whether row-wise normalisation was intended.
        adj /= sum_of_rows
        adj = torch.nan_to_num(adj)  # rows summing to zero give NaN/inf
        _, v = LA.eig(adj.numpy())
        v = v.transpose()
        data.x = torch.Tensor(v).float()
        return data


class Node2Vec(BaseTransform):
    """Node features from node2vec embeddings of the connectivity graph.

    Args:
        feature_dim: Embedding size per graph.
        walk_length: Length of each random walk.
        num_walks: Number of walks per node.
        num_workers: Worker count passed to node2vec.
        window, min_count, batch_words: Forwarded to ``node2vec.fit``.
    """

    def __init__(
        self,
        feature_dim=32,
        walk_length=5,
        num_walks=200,
        num_workers=4,
        window=10,
        min_count=1,
        batch_words=4,
    ):
        super(Node2Vec, self).__init__()
        self.feature_dim = feature_dim
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.num_workers = num_workers
        self.window = window
        self.min_count = min_count
        self.batch_words = batch_words

    def __call__(self, data):
        """
        Sets the node features to node2vec embeddings. When the adjacency
        matrix has negative weights, its positive and negative parts are
        embedded separately and concatenated, doubling the feature size.
        :param data: BrainData
        :return: the same data object with ``x`` replaced
        """
        adj = torch.sparse_coo_tensor(
            data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes]
        )
        adj = adj.to_dense()
        if (adj < 0).int().sum() > 0:
            # split the adjacency matrix into two (negative and positive) parts
            pos_adj = adj.clone()
            pos_adj[adj < 0] = 0
            neg_adj = adj.clone()
            neg_adj[adj > 0] = 0
            neg_adj = -neg_adj
            adjs = [pos_adj, neg_adj]
        else:
            adjs = [adj]

        xs = []
        for part in adjs:
            x = torch.zeros((data.num_nodes, self.feature_dim))
            # ``from_numpy_matrix`` no longer exists in networkx 3;
            # ``from_numpy_array`` builds the same weighted graph.
            graph = nx.from_numpy_array(part.numpy())
            node2vec = Node2Vec_(
                graph,
                dimensions=self.feature_dim,
                walk_length=self.walk_length,
                num_walks=self.num_walks,
                workers=self.num_workers,
            )
            model = node2vec.fit(
                window=self.window,
                min_count=self.min_count,
                batch_words=self.batch_words,
            )
            for i in range(data.num_nodes):
                # The embedding vocabulary is keyed by the node id as a string.
                x[i] = torch.Tensor(model.wv[f"{i}"].copy())
            xs.append(x)
        data.x = torch.cat(xs, dim=-1)
        return data

    def __str__(self):
        return "Node2Vec"
feature with node2vec transform. 183 | :param data: BrainData 184 | :return: torch.Tensor 185 | """ 186 | adj = torch.sparse_coo_tensor( 187 | data.edge_index, data.edge_attr, [data.num_nodes, data.num_nodes] 188 | ) 189 | adj = adj.to_dense() 190 | if (adj < 0).int().sum() > 0: 191 | # split the adjacency matrix into two (negative and positive) parts 192 | pos_adj = adj.clone() 193 | pos_adj[adj < 0] = 0 194 | neg_adj = adj.clone() 195 | neg_adj[adj > 0] = 0 196 | neg_adj = -neg_adj 197 | adjs = [pos_adj, neg_adj] 198 | else: 199 | adjs = [adj] 200 | 201 | xs = [] 202 | for adj in adjs: 203 | x = torch.zeros((data.num_nodes, self.feature_dim)) 204 | graph = from_numpy_matrix(adj.numpy()) 205 | node2vec = Node2Vec_( 206 | graph, 207 | dimensions=self.feature_dim, 208 | walk_length=self.walk_length, 209 | num_walks=self.num_walks, 210 | workers=self.num_workers, 211 | ) 212 | model = node2vec.fit( 213 | window=self.window, 214 | min_count=self.min_count, 215 | batch_words=self.batch_words, 216 | ) 217 | for i in range(data.num_nodes): 218 | x[i] = torch.Tensor(model.wv[f"{i}"].copy()) 219 | xs.append(x) 220 | data.x = torch.cat(xs, dim=-1) 221 | return data 222 | 223 | def __str__(self): 224 | return "Node2Vec" 225 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xGW-GAT 2 | 3 | This repository is the official implementation of `xGW-GAT`, an explainable, graph attention network for n-ary, transductive, classification tasks for functional brain connectomes and gait impairment severity. Our associated paper, **"An Explainable Geometric-Weighted Graph Attention Network for Identifying Functional Networks Associated with Gait Impairment"** has been accepted to MICCAI 2023 and is supported by the MICCAI 2023 STAR award. 
Check out our [preprint](https://arxiv.org/abs/2307.13108), [paper](https://link.springer.com/chapter/10.1007/978-3-031-43895-0_68), and our [oral talk](https://youtu.be/ZqoIfHHcIXc)! 4 | 5 | Our pipeline consists of three modules: 6 | 1) A stratified, learning-based sample selection method leveraging Riemannian metrics for connectome similarity comparisons 7 | 2) An attention-based, brain network-oriented prediction model 8 | 3) An explanation generator for individual and global attention masks that highlight salient Regions of Interest (ROIs) in predicting gait impairment severity states. 9 | 10 |

11 | Screenshot 2023-08-01 at 12 19 32 PM 12 |

13 | 14 | > **An Explainable Geometric-Weighted Graph Attention Network for Identifying Functional Networks Associated with Gait Impairment** 15 | > 16 | > [Favour Nerrise](mailto:fnerrise@stanford.edu)1, [Qingyu Zhao]()2, [Kathleen L. Poston]()3, [Kilian M. Pohl]()2, [Ehsan Adeli]()2
17 | > 1Department of Electrical Engineering, Stanford University, Stanford, CA, USA
18 | > 2Department of Psychiatry and Behavioral Sciences, Stanford University, Stanford, CA
19 | > 3Dept. of Neurology and Neurological Sciences, Stanford University, Stanford, CA, USA
20 | > 21 | > **Abstract:** *One of the hallmark symptoms of Parkinson's Disease (PD) is the progressive loss of postural reflexes, which eventually leads to gait difficulties and balance problems. Identifying disruptions in brain function associated with gait impairment could be crucial in better understanding PD motor progression, thus advancing the development of more effective and personalized therapeutics. In this work, we present an explainable, geometric, weighted-graph attention neural network (xGW-GAT) to identify functional networks predictive of the progression of gait difficulties in individuals with PD. xGW-GAT predicts the multi-class gait impairment on the MDS Unified PD Rating Scale (MDS-UPDRS). Our computational- and data-efficient model represents functional connectomes as symmetric positive definite (SPD) matrices on a Riemannian manifold to explicitly encode pairwise interactions of entire connectomes, based on which we learn an attention mask yielding individual- and group-level explainability. Applied to our resting-state functional MRI (rs-fMRI) dataset of individuals with PD, xGW-GAT identifies functional connectivity patterns associated with gait impairment in PD and offers interpretable explanations of functional subnetworks associated with motor impairment. Our model successfully outperforms several existing methods while simultaneously revealing clinically relevant connectivity patterns.* 22 | 23 | ## Installation Instructions 24 | - Download the ZIP folder or clone this repository, e.g., ```git clone https://github.com/favour-nerrise/xGW-GAT.git```. 
25 | 26 | ### Dependencies 27 | This code was prepared using Python 3.10.4 and depends on the following packages: 28 | 29 | * torch==2.0.1 30 | * #PyG 31 | * -f https://data.pyg.org/whl/torch-2.0.1+cu118.html 32 | * torch_geometric==2.3.1 33 | * torch-cluster 34 | * torch-scatter 35 | * torch-sparse 36 | * torch-spline-conv 37 | * scikit-learn >= 0.24.1 38 | * pymanopt==2.1.1 39 | * numpy 40 | * pandas==2.0.3 41 | * scipy==1.11.1 42 | 43 | See more details and install all required packages using ```pip install -r requirements.txt```. We recommend running all code and making installations in a virtual environment to prevent package conflicts. See [here](https://docs.python.org/3/library/venv.html) for more details on how to do so. 44 | 45 | ## Getting Started 46 | ### Prepare Your Data 47 | * Extract functional correlation matrices from your chosen dataset, e.g., PPMI, save them as an ```.npy``` file, and place the dataset files in the ```./datasets/``` folder under the root folder. The saved matrix should be of shape ```(num_subjects, node_dim, node_dim)```. 48 | * Save the associated subject metrics, e.g., gait impairment severity scores, as an ```.npy``` file and also place it in the ```./datasets/``` folder. The saved array should be of shape ```(num_subjects)```. 49 | * Configure the ```brain_dataset.py``` and related files in the associated folder to correctly read in and process your dataset. Code has been provided for our use case of the ```PRIVATE``` dataset. 50 | 51 | ## Calling the Model 52 | 53 | ```bash 54 | python main.py --dataset=<dataset_name> --model_name=gatv2 --sample_selection --explain 55 | ``` 56 | The ```--explain``` argument is optional; it generates attention-based explanations of your model's predictions and saves the related explanation data to the ```outputs/explanations/``` folder. 
57 | 58 | ## Configuration Options 59 | 60 | Different configurations for the models and dataset can be specified in the ```main.py``` file, such as ```num_epochs```, ```num_classes```, and ```hidden_dim```. 61 | 62 | ## Hyperparameter Tuning 63 | This pipeline was configured for hyperparameter optimization with [nni](https://github.com/microsoft/nni). Tuning configurations can be modified in the ```src/nni_configs/config.yml``` file. Using a Colab/Jupyter notebook, this can be done as follows: 64 | Create a free [ngrok.com](https://ngrok.com/) account and copy your *AuthToken* to be able to use the UI. Then run the following lines. 65 | ``` 66 | ! wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip # download ngrok and unzip it 67 | ! unzip ngrok-stable-linux-amd64.zip 68 | ``` 69 | ``` 70 | ! ./ngrok authtoken <AuthToken> 71 | ``` 72 | ``` 73 | ! nnictl create --config src/nni_configs/config.yml --port 5000 & 74 | ``` 75 | ``` 76 | get_ipython().system_raw('./ngrok http 5000 &') # get experiment id 77 | ``` 78 | ``` 79 | ! curl -s http://localhost:4040/api/tunnels # don't change the port number 4040 80 | ``` 81 | ``` 82 | !nnictl stop <experiment_id> # stop running experiment 83 | ``` 84 | 85 | 86 | ## Acknowledgments 87 | This work was partially supported by NIH grants (AA010723, NS115114, P30AG066515), Stanford School of Medicine Department of Psychiatry and Behavioral Sciences Jaswa Innovator Award, UST (a Stanford AI Lab alliance member), and the Stanford Institute for Human-Centered AI (HAI) Google Cloud credits. FN is funded by the Stanford Graduate Fellowship and the Stanford NeuroTech Training Program Fellowship. 88 | 89 | This code was developed by Favour Nerrise (fnerrise@stanford.edu). 
We also thank [@Henny-Jie](https://github.com/HennyJie/), [@basiralab](https://github.com/basiralab/), and [@pyg-team](https://github.com/pyg-team/) for their related works and open-source code on [IBGNN](https://github.com/HennyJie/IBGNN) + [BrainGB](https://github.com/HennyJie/BrainGB), [RegGNN](https://github.com/basiralab/RegGNN/), and [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric), respectively, which served as great resources for developing our methods and codebase. 90 | 91 | 92 | ## Citation 93 | Please cite our paper when using **xGW-GAT**: 94 | ```latex 95 | @inproceedings{nerrise2023explainable, 96 | title={An Explainable Geometric-Weighted Graph Attention Network for Identifying Functional Networks Associated with Gait Impairment}, 97 | author={Nerrise, Favour and Zhao, Qingyu and Poston, Kathleen L and Pohl, Kilian M and Adeli, Ehsan}, 98 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 99 | pages={723--733}, 100 | year={2023}, 101 | organization={Springer} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /nni_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | HPO Quickstart with PyTorch 3 | =========================== 4 | This tutorial optimizes the model in `official PyTorch quickstart`_ with auto-tuning. 5 | 6 | The tutorial consists of 4 steps: 7 | 8 | 1. Modify the model for auto-tuning. 9 | 2. Define hyperparameters' search space. 10 | 3. Configure the experiment. 11 | 4. Run the experiment. 12 | 13 | .. _official PyTorch quickstart: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html 14 | """ 15 | from nni.experiment import Experiment 16 | 17 | # %% 18 | # Step 1: Prepare the model 19 | # ------------------------- 20 | # In first step, we need to prepare the model to be tuned. 21 | # 22 | # The model should be put in a separate script. 
# It will be evaluated many times concurrently,
# and possibly will be trained on distributed platforms.
#
# In this repository the trial script is ``main.py`` (see ``trial_command`` below).
#
# In short, it is a PyTorch model with 3 additional API calls:
#
# 1. Use :func:`nni.get_next_parameter` to fetch the hyperparameters to be evaluated.
# 2. Use :func:`nni.report_intermediate_result` to report per-epoch accuracy metrics.
# 3. Use :func:`nni.report_final_result` to report final accuracy.
#
# Please understand the model code before continuing to the next step.

# %%
# Step 2: Define search space
# ---------------------------
# The original tutorial tunes 3 hyperparameters:
# *features*, *lr*, and *momentum*.
# Only *lr* is active in the search space below; the other two entries are commented out.
#
# Here we need to define their *search space* so the tuning algorithm can sample them in desired range.
#
# Assuming we have the following prior knowledge for these hyperparameters:
#
# 1. *features* should be one of 128, 256, 512, 1024.
# 2. *lr* should be a float between 0.0001 and 0.1, sampled log-uniformly.
# 3. *momentum* should be a float between 0 and 1.
#
# In NNI, the space of *features* is called ``choice``;
# the space of *lr* is called ``loguniform``;
# and the space of *momentum* is called ``uniform``.
# You may have noticed, these names are derived from ``numpy.random``.
#
# For the full specification of the search space, check the NNI search space reference.
#
# Now we can define the search space as follows:

search_space = {
    # 'features': {'_type': 'choice', '_value': [128, 256, 512, 1024]},
    "lr": {"_type": "loguniform", "_value": [0.0001, 0.1]},
    # 'momentum': {'_type': 'uniform', '_value': [0, 1]},
}

# %%
# Step 3: Configure the experiment
# --------------------------------
# NNI uses an *experiment* to manage the HPO process.
# The *experiment config* defines how to train the models and how to explore the search space.
#
# In this tutorial we use a *local* mode experiment,
# which means models will be trained on local machine, without using any special training platform.

experiment = Experiment("local")

# %%
# Now we start to configure the experiment.
#
# Configure trial code
# ^^^^^^^^^^^^^^^^^^^^
# In NNI evaluation of each hyperparameter set is called a *trial*.
# So the model script is called *trial code*.
experiment.config.trial_command = "python main.py --enable_nni"
experiment.config.trial_code_directory = "."
# %%
# When ``trial_code_directory`` is a relative path, it relates to current working directory.
# To run ``main.py`` in a different path, you can set trial code directory to ``Path(__file__).parent``.
# (``__file__`` is only available in standard Python, not in Jupyter Notebook.)
#
# .. attention::
#
#     If you are using Linux system without Conda,
#     you may need to change ``"python main.py --enable_nni"`` to ``"python3 main.py --enable_nni"``.

# %%
# Configure search space
# ^^^^^^^^^^^^^^^^^^^^^^
experiment.config.search_space = search_space

# %%
# Configure tuning algorithm
# ^^^^^^^^^^^^^^^^^^^^^^^^^^
# Here we use the TPE tuner and ask it to maximize the reported metric.
experiment.config.tuner.name = "TPE"
experiment.config.tuner.class_args["optimize_mode"] = "maximize"

# %%
# Configure how many trials to run
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Here we evaluate 4 sets of hyperparameters in total, and concurrently evaluate 2 sets at a time.
# The whole experiment is additionally limited to one hour of running time.
experiment.config.max_trial_number = 4
experiment.config.trial_concurrency = 2
experiment.config.max_experiment_duration = "1h"
# %%
# ``max_experiment_duration = "1h"`` above limits the running time.
#
# If neither ``max_trial_number`` nor ``max_experiment_duration`` are set,
# the experiment will run forever until you press Ctrl-C.
#
# .. note::
#
#     ``max_trial_number`` is set to 4 here for a fast example.
#     In real world it should be set to a larger number.
#     With default config TPE tuner requires 20 trials to warm up.

# %%
# Step 4: Run the experiment
# --------------------------
# Now the experiment is ready. Choose a port and launch it. (Here we use port 8084.)
#
# You can use the web portal to view experiment status: http://localhost:8084.
experiment.run(8084)

# %%
# After the experiment is done
# ----------------------------
# Everything is done and it is safe to exit now. The following are optional.
#
# If you are using standard Python instead of Jupyter Notebook,
# you can add ``input()`` or ``signal.pause()`` to prevent Python from exiting,
# allowing you to view the web portal after the experiment is done.

# input('Press enter to quit')
experiment.save()

experiment.stop()

# %%
# :meth:`nni.experiment.Experiment.stop` is automatically invoked when Python exits,
# so it can be omitted in your code.
#
# After the experiment is stopped, you can run :meth:`nni.experiment.Experiment.view` to restart web portal.
#
# .. tip::
#
#     This example uses the NNI Python API to create the experiment.
#
#     You can also create and manage experiments with :doc:`command line tool <../hpo_nnictl/nnictl>`.
160 | 161 | 162 | # def test_seed(script_name: str, dataset='SU', enable_nni=False, num_trials=1000, 163 | # node_features='degree_bin'): 164 | # seeds = [] 165 | # num_trials = 10 if enable_nni else num_trials 166 | # print(f'running {num_trials} trials') 167 | # for i in range(num_trials): 168 | # seeds.append(random.randint(100000, 10000000)) 169 | 170 | # default_param = { 171 | # 'dataset_name': dataset, 172 | # 'node_features': node_features, 173 | # 'weight_decay': 5e-4, 174 | # 'epochs': 100, 175 | # 'n_MLP_layer': 1, 176 | # 'n_GNN_layers': 3, 177 | # # 'hidden_dim': 360, 178 | # # 'edge_emb_dim':256, 179 | # 'lr': 0.001, 180 | # } 181 | 182 | # sp = SP.SimpleParam(default=default_param) 183 | # params = sp(from_='./src/nni_configs/search_space.json', preprocess_nni=False) 184 | 185 | # param_str = ' '.join([f'--{k} {v}' for k, v in params.items()]) 186 | 187 | # cmd = f'python {script_name} {param_str}' 188 | # cmd += ' --enable_nni' if enable_nni else '' 189 | # print(cmd) 190 | # os.system(cmd) 191 | 192 | 193 | # if __name__ == '__main__': 194 | # parser = argparse.ArgumentParser() 195 | # parser.add_argument('--target', type=str, default='main.py') 196 | # parser.add_argument('--enable_nni', action='store_true') 197 | # parser.add_argument('--dataset', type=str, default='SU') 198 | # parser.add_argument('--trials', type=int, default=1000) 199 | # parser.add_argument('--node_features', type=str, 200 | # choices=['identity', 'degree', 'degree_bin', 'LDP', 'node2vec', 'adj'], 201 | # default='adj') 202 | # args = parser.parse_args() 203 | 204 | # cwd = os.getcwd() 205 | # print(cwd) 206 | 207 | # test_seed(args.target, dataset=args.dataset, enable_nni=args.enable_nni, num_trials=args.trials, 208 | # node_features=args.node_features) 209 | -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
class BinaryFocalLoss(nn.Module):
    """
    Focal loss (https://arxiv.org/abs/1708.02002) for binary classification.

    The element-wise binary cross entropy is weighted by ``alpha * (1 - pt) ** gamma``
    with ``pt = exp(-BCE)``, so that confident predictions contribute less to the loss.

    :param alpha: constant multiplier applied to every element of the loss
    :param gamma: focusing exponent applied to ``1 - pt``
    :param with_logits: if True, ``inputs`` are raw logits and
        ``F.binary_cross_entropy_with_logits`` is used; otherwise ``inputs`` are
        probabilities and ``F.binary_cross_entropy`` is used
    :param reduce: batch reduction, ``"mean"`` or ``"sum"``; any other value
        (e.g. ``"none"``) returns the unreduced element-wise loss
    """

    def __init__(self, alpha=1, gamma=1, with_logits=False, reduce="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.with_logits = with_logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        """
        Compute the focal loss between ``inputs`` and ``targets``.

        :param inputs: predicted logits or probabilities (see ``with_logits``)
        :param targets: float tensor of targets, same shape as ``inputs``
        :return: scalar tensor for "mean"/"sum", element-wise tensor otherwise
        """
        bce_fn = (
            F.binary_cross_entropy_with_logits
            if self.with_logits
            else F.binary_cross_entropy
        )
        # reduction="none" replaces the deprecated reduce=False keyword, which
        # emitted a warning on every call.
        BCE_loss = bce_fn(inputs, targets, reduction="none")
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce == "mean":
            return torch.mean(focal_loss)
        if self.reduce == "sum":
            return torch.sum(focal_loss)
        return focal_loss
def focal_loss(
    input,
    target,
    gamma=1,
    eps=1e-7,
    with_logits=True,
    ignore_index=-100,
    reduction="mean",
    smooth_eps=None,
):
    r"""
    A function version of focal loss, meant to be easily swappable with F.cross_entropy.

    The per-sample loss implemented here is
    L_{focal} = - \sum_c y_c (1 - p_c)^\gamma \log p_c,
    where y is the (optionally label-smoothed) one-hot target and p the predicted
    class probabilities.

    :param input: tensor whose last dimension indexes classes; raw logits if
        ``with_logits`` is true (a softmax is applied), probabilities otherwise
    :param target: integer class targets (i.e. sparse, not one-hot), the same
        convention as nn.CrossEntropyLoss
    :param gamma: focusing exponent
    :param eps: probabilities are clamped to ``[eps, 1 - eps]`` before the log
    :param with_logits: whether ``input`` holds logits or probabilities
    :param ignore_index: positions where ``target == ignore_index`` do not
        contribute to the loss
    :param reduction: batch behaviour, one of 'none', 'mean', 'sum'
    :param smooth_eps: label-smoothing factor; ``None`` or 0 disables smoothing
    :return: reduced loss, or the per-sample losses of the kept positions for 'none'
    """
    smooth_eps = smooth_eps or 0

    # Ignored positions may carry values such as -100 that are not valid class
    # ids and would make F.one_hot raise.  Encode them as class 0 instead; they
    # are removed again by the `keep` mask before the reduction.
    keep = target != ignore_index
    safe_target = target.masked_fill(~keep, 0)

    # make target
    y = F.one_hot(safe_target, input.size(-1))

    # apply label smoothing according to target = [eps/K, eps/K, ..., (1-eps) + eps/K, eps/K, eps/K, ...]
    if smooth_eps > 0:
        y = y * (1 - smooth_eps) + smooth_eps / y.size(-1)

    if with_logits:
        pt = F.softmax(input, dim=-1)
    else:
        pt = input

    # Clamp so the log below never sees an exact zero, which can happen when
    # probabilities are passed in directly.
    pt = pt.clamp(eps, 1.0 - eps)

    loss = -y * torch.log(pt)  # cross entropy
    loss *= (1 - pt) ** gamma  # focal loss factor
    loss = torch.sum(loss, dim=-1)

    # drop the positions whose target equals ignore_index
    loss = loss[keep]

    # batch reduction
    if reduction == "mean":
        return torch.mean(loss, dim=-1)
    elif reduction == "sum":
        return torch.sum(loss, dim=-1)
    else:  # 'none'
        return loss
def jsd_loss(z1, z2, discriminator, pos_mask, neg_mask=None):
    """
    Jensen-Shannon style contrastive objective over discriminator scores.

    :param z1: first set of embeddings, passed to ``discriminator``
    :param z2: second set of embeddings, passed to ``discriminator``
    :param discriminator: callable returning a score tensor for ``(z1, z2)``
    :param pos_mask: mask selecting positive pairs in the score tensor
    :param neg_mask: mask selecting negative pairs; defaults to ``1 - pos_mask``
    :return: negative-pair expectation minus positive-pair expectation
    """
    neg_mask = (1 - pos_mask) if neg_mask is None else neg_mask
    neg_count = neg_mask.int().sum()
    pos_count = pos_mask.int().sum()
    log_two = np.log(2)

    scores = discriminator(z1, z2)
    masked_pos = -scores * pos_mask
    masked_neg = scores * neg_mask

    # Masked-out entries become 0, and log(2) - softplus(0) == 0, so they do
    # not contribute to either sum.
    pos_term = (log_two - F.softplus(masked_pos)).sum() / pos_count
    neg_term = (F.softplus(-masked_neg) + masked_neg - log_two).sum() / neg_count

    return neg_term - pos_term
class GATv2(nn.Module):
    """
    Graph-level classifier built from stacked ``MPGATConv`` attention layers.

    ``args.n_GNN_layers`` attention blocks (attention layer, linear projection,
    LeakyReLU, batch norm) are followed by the pooling selected through
    ``args.pooling`` ("concat", "sum" or "mean") and a three-layer MLP that
    outputs ``args.num_classes`` scores.

    :param in_channels: number of input features per node
    :param args: namespace providing ``dropout``, ``edge_emb_dim``, ``explain``,
        ``hidden_dim``, ``gat_mp_type``, ``pooling``, ``num_classes``,
        ``num_heads``, ``n_GNN_layers`` and ``num_nodes``
    :raises ValueError: if ``args.pooling`` is not a supported pooling mode
    """

    def __init__(self, in_channels, args):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.convs = torch.nn.ModuleList()

        self.dropout = args.dropout
        self.edge_dim = args.edge_emb_dim
        self.explain = args.explain
        self.hidden_dim = args.hidden_dim
        self.gat_mp_type = args.gat_mp_type
        self.pooling = args.pooling
        self.num_classes = args.num_classes
        self.num_heads = args.num_heads
        self.num_layers = args.n_GNN_layers
        self.num_nodes = args.num_nodes

        if self.pooling not in ("concat", "sum", "mean"):
            # Previously an unknown mode silently produced a head with zero
            # input features (and no dedicated final attention layer).
            raise ValueError(f"Invalid pooling mode {self.pooling}")

        # Feature width entering the next attention layer: the raw node features
        # for the first layer, hidden_dim for every layer after it.
        gat_input_dim = in_channels

        for _ in range(self.num_layers - 1):
            conv = Sequential(
                "x, edge_index, edge_attr",
                [
                    self._attention_entry(gat_input_dim),
                    nn.Linear(self.hidden_dim * self.num_heads, self.hidden_dim),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.BatchNorm1d(self.hidden_dim),
                ],
            )
            gat_input_dim = self.hidden_dim
            self.convs.append(conv)

        node_dim = self.hidden_dim
        if self.pooling == "concat":
            conv = Sequential(
                "x, edge_index, edge_attr",
                [
                    self._attention_entry(gat_input_dim),
                    nn.Linear(self.hidden_dim * self.num_heads, 64),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Linear(64, node_dim),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.BatchNorm1d(node_dim),
                ],
            )
            # every node keeps its own embedding, flattened per graph
            head_input_dim = node_dim * self.num_nodes
        else:  # "sum" or "mean"
            conv = Sequential(
                "x, edge_index, edge_attr",
                [
                    self._attention_entry(gat_input_dim),
                    nn.Linear(self.hidden_dim * self.num_heads, self.hidden_dim),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.BatchNorm1d(node_dim),
                ],
            )
            head_input_dim = node_dim

        self.convs.append(conv)

        self.fcn = nn.Sequential(
            nn.Linear(head_input_dim, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, 32),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(32, self.num_classes),
        )

    def _attention_entry(self, input_dim):
        """Return the ``(module, signature)`` Sequential entry for one MPGATConv layer."""
        return (
            MPGATConv(
                input_dim,
                self.hidden_dim,
                self.num_heads,
                dropout=self.dropout,
                gat_mp_type=self.gat_mp_type,
            ),
            "x, edge_index, edge_attr -> x",
        )

    def forward(self, data, edge_index, edge_attr, batch):
        """
        Classify the graph(s) held in ``data``.

        :param data: graph batch; ``data.x`` holds the node features
        :param edge_index: edge indices passed to every attention layer
        :param edge_attr: edge weights; their absolute value is used
        :param batch: node-to-graph assignment used by "sum"/"mean" pooling
        :return: tensor of class scores produced by the MLP head
        """
        x = data.x
        edge_attr = torch.abs(edge_attr)

        # save attention only when explaining
        explaining = self.explain and (data.num_nodes == self.num_nodes)

        for i, conv in enumerate(self.convs):
            # bz*nodes, hidden
            if explaining and i == self.num_layers - 1:
                # NOTE(review): relies on the Sequential wrapper accepting
                # return_attention_weights and returning (x, attn) — confirm
                # against the installed torch_geometric version.
                x, attn = conv(
                    x, edge_index, edge_attr, return_attention_weights=True
                )
                save_attn_scores(attn, data)  # attn = (edge_index, alpha coefficients)
            else:
                # Each layer is applied exactly once; previously this call also
                # ran after the explain branch, applying every layer twice.
                x = conv(x, edge_index, edge_attr)

        if self.pooling == "concat":
            x = x.reshape((x.shape[0] // self.num_nodes, -1))
        elif self.pooling == "sum":
            x = global_add_pool(x, batch)  # [N, F]
        elif self.pooling == "mean":
            x = global_mean_pool(x, batch)  # [N, F]

        out = self.fcn(x)
        return out
17 | 18 | manifold = SPD(k, d) 19 | 20 | Elements of Sym+(d)^k are represented as arrays of size kxdxd where every dxd slice is an SPD matrix, i.e., a 21 | symmetric matrix S with positive eigenvalues. 22 | 23 | The Riemannian metric used is the product Log-Euclidean metric that is induced by the standard Euclidean trace 24 | metric; see 25 | Arsigny, V., Fillard, P., Pennec, X., and Ayache., N. 26 | Fast and simple computations on tensors with Log-Euclidean metrics. 27 | """ 28 | 29 | def __init__(self, k=1, d=3): 30 | if d <= 0: 31 | raise RuntimeError("d must be an integer no less than 1.") 32 | 33 | if k == 1: 34 | self._name = ( 35 | "Manifold of symmetric positive definite {d} x {d} matrices".format( 36 | d=d, k=k 37 | ) 38 | ) 39 | elif k > 1: 40 | self._name = "Manifold of {k} symmetric positive definite {d} x {d} matrices (Sym^+({d}))^{k}".format( 41 | d=d, k=k 42 | ) 43 | else: 44 | raise RuntimeError("k must be an integer no less than 1.") 45 | 46 | self._k = k 47 | self._d = d 48 | 49 | def __str__(self): 50 | return self._name 51 | 52 | @property 53 | def dim(self): 54 | return int((self._d * (self._d + 1) / 2) * self._k) 55 | 56 | @property 57 | def typicaldist(self): 58 | # typical affine invariant distance 59 | return np.sqrt(self._k * 6) 60 | 61 | def inner(self, S, X, Y): 62 | """product metric""" 63 | return np.sum(np.einsum("...ij,...ij", X, Y)) 64 | 65 | def norm(self, S, X): 66 | """norm from product metric""" 67 | return np.sqrt(self.inner(S, X, X)) 68 | 69 | def proj(self, X, H): 70 | """orthogonal (with respect to the Euclidean inner product) projection of ambient 71 | vector ((k,3,3) array) onto the tangent space at X""" 72 | return dlog(X, multisym(H)) 73 | 74 | def egrad2rgrad(self, X, D): 75 | # should be adj_dexp instead of dexp (however, dexp appears to be self-adjoint for symmetric matrices) 76 | return dexp(log_mat(X), multisym(D)) 77 | 78 | def ehess2rhess(self, X, Hess): 79 | # TODO 80 | return 81 | 82 | def exp(self, S, X): 83 | 
"""Riemannian exponential with base point S evaluated at X""" 84 | assert S.shape == X.shape 85 | 86 | # (avoid additional exp/log) 87 | Y = X + log_mat(S) 88 | vals, vecs = la.eigh(Y) 89 | return np.einsum("...ij,...j,...kj", vecs, np.exp(vals), vecs) 90 | 91 | retr = exp 92 | 93 | def log(self, S, U): 94 | """Riemannian logarithm with base point S evaluated at U""" 95 | assert S.shape == U.shape 96 | 97 | # (avoid additional log/exp) 98 | return log_mat(U) - log_mat(S) 99 | 100 | def geopoint(self, S, T, t): 101 | """Evaluate the geodesic from S to T at time t in [0, 1]""" 102 | assert S.shape == T.shape and np.isscalar(t) 103 | 104 | return self.exp(S, t * self.log(S, T)) 105 | 106 | def rand(self): 107 | S = np.random.random((self._k, self._d, self._d)) 108 | return np.einsum("...ij,...kj", S, S) 109 | 110 | def randvec(self, X): 111 | Y = self.rand() 112 | y = self.log(X, Y) 113 | return y / self.norm(X, y) 114 | 115 | def zerovec(self, X): 116 | return np.zeros((self._k, self._d, self._d)) 117 | 118 | def transp(self, S, T, X): 119 | """Parallel transport for Sym+(d)^k. 120 | :param S: element of Symp+(d)^k 121 | :param T: element of Symp+(d)^k 122 | :param X: tangent vector at S 123 | :return: parallel transport of X to the tangent space at T 124 | """ 125 | assert S.shape == T.shape == X.shape 126 | 127 | # if X were not in algebra but at tangent space at S 128 | # return dexp(log_mat(T), dlog(S, X)) 129 | 130 | return X 131 | 132 | def eleminner(self, R, X, Y): 133 | """element-wise inner product""" 134 | return np.einsum("...ij,...ij", X, Y) 135 | 136 | def elemnorm(self, R, X): 137 | """element-wise norm""" 138 | return np.sqrt(self.eleminner(R, X, X)) 139 | 140 | def projToGeodesic(self, X, Y, P, max_iter=10): 141 | """ 142 | :arg X, Y: elements of Symp+(d)^k defining geodesic X->Y. 143 | :arg P: element of Symp+(d)^k to be projected to X->Y. 
144 | :returns: projection of P to X->Y 145 | """ 146 | 147 | assert X.shape == Y.shape 148 | assert Y.shape == P.shape 149 | 150 | # all tagent vectors in common space i.e. algebra 151 | v = self.log(X, Y) 152 | v /= self.norm(X, v) 153 | 154 | w = self.log(X, P) 155 | d = self.inner(X, v, w) 156 | 157 | return self.exp(X, d * v) 158 | 159 | def pairmean(self, S, T): 160 | assert S.shape == T.shape 161 | 162 | return self.exp(S, 0.5 * self.log(S, T)) 163 | 164 | def dist(self, S, T): 165 | """Distance function in Sym+(d)^k""" 166 | return self.norm(S, self.log(S, T)) 167 | 168 | def adjJacobi(self, S, T, t, X): 169 | """Evaluates an adjoint Jacobi field along the geodesic gam from S to T 170 | :param S: element of the space of differential coordinates 171 | :param T: element of the space of differential coordinates 172 | :param t: scalar in [0,1] 173 | :param X: tangent vector at gam(t) 174 | :return: tangent vector at X 175 | """ 176 | assert S.shape == T.shape == X.shape and np.isscalar(t) 177 | 178 | U = self.geopoint(S, T, t) 179 | return (1 - t) * self.transp(U, S, X) 180 | 181 | def adjDxgeo(self, S, T, t, X): 182 | """Evaluates the adjoint of the differential of the geodesic gamma from S to T w.r.t the starting point S at X, 183 | i.e, the adjoint of d_S gamma(t; ., T) applied to X, which is en element of the tangent space at gamma(t). 184 | """ 185 | assert S.shape == T.shape == X.shape and np.isscalar(t) 186 | 187 | return self.adjJacobi(S, T, t, X) 188 | 189 | def adjDygeo(self, S, T, t, X): 190 | """Evaluates the adjoint of the differential of the geodesic gamma from S to T w.r.t the endpoint T at X, 191 | i.e, the adjoint of d_T gamma(t; S, .) applied to X, which is en element of the tangent space at gamma(t). 
def vectime3d(x, A):
    """
    Scale every slice of a 3-D array by the matching entry of a vector.

    :param x: 2-D row or column vector of length k, i.e. an array of shape (1, k) or (k, 1)
    :param A: array of size k x n x m
    :return: k x n x m array such that the j-th n x m slice of A is multiplied with the j-th element of x
    """
    # Check ranks, not element counts: the previous test on np.size(x.shape[0])
    # could never hold, so every call failed.
    assert x.ndim == 2 and A.ndim == 3
    assert x.shape[0] == 1 or x.shape[1] == 1
    assert x.shape[0] == A.shape[0] or x.shape[1] == A.shape[0]

    # Shape the vector as (k, 1, 1) so it broadcasts over the n x m slices,
    # whether it arrived as a row or as a column.
    scale = np.reshape(x, (-1, 1, 1))
    return scale * A
class xGW_GAT:
    """Command-line driver: K-fold training and evaluation of a GNN on a BrainDataset."""

    @staticmethod
    def _build_parser():
        """Return the argument parser holding every experiment option."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--dataset", type=str, default="PRIVATE")
        parser.add_argument(
            "--model_name", type=str, default="gatv2", choices=["gcn", "gatv2"]
        )
        parser.add_argument("--num_classes", type=int, default=4)
        parser.add_argument(
            "--node_features",
            type=str,
            default="adj",
            choices=[
                "identity",
                "degree",
                "degree_bin",
                "LDP",
                "node2vec",
                "adj",
                "diff_matrix",
                "eigenvector",
                "eigen_norm",
            ],
        )
        parser.add_argument(
            "--centrality_measure",
            type=str,
            default="node",
            choices=[
                "abs",
                "geo",
                "tan",
                "node",
                "eigen",
                "close",
                "concat_orig",
                "concat_scale",
            ],
            help="Chooses the topological measure to be used",
        )
        parser.add_argument("--epochs", type=int, default=200)
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--weight_decay", type=float, default=2e-2)
        parser.add_argument(
            "--gcn_mp_type",
            type=str,
            default="node_concat",
            choices=[
                "weighted_sum",
                "bin_concat",
                "edge_weight_concat",
                "edge_node_concat",
                "node_concat",
            ],
        )
        parser.add_argument(
            "--gat_mp_type",
            type=str,
            default="attention_weighted",
            choices=[
                "attention_weighted",
                "attention_edge_weighted",
                "sum_attention_edge",
                "edge_node_concat",
                "node_concat",
            ],
        )
        parser.add_argument(
            "--pooling", type=str, choices=["sum", "concat", "mean"], default="concat"
        )
        parser.add_argument("--n_GNN_layers", type=int, default=2)
        parser.add_argument("--n_MLP_layers", type=int, default=2)
        parser.add_argument("--num_heads", type=int, default=4)
        parser.add_argument("--hidden_dim", type=int, default=8)
        parser.add_argument("--edge_emb_dim", type=int, default=1)
        parser.add_argument("--bucket_sz", type=float, default=0.05)
        parser.add_argument("--dropout", type=float, default=0.1)
        parser.add_argument("--repeat", type=int, default=1)
        parser.add_argument("--k_fold_splits", type=int, default=4)
        # `type=list` would split "--k_list 10" into ['1', '0']; read one or
        # more integers instead so the values match the integer default.
        parser.add_argument("--k_list", type=int, nargs="+", default=[4])
        parser.add_argument("--n_select_splits", type=int, default=2)
        parser.add_argument("--test_interval", type=int, default=10)
        parser.add_argument("--train_batch_size", type=int, default=2)
        parser.add_argument("--test_batch_size", type=int, default=1)
        parser.add_argument("--seed", type=int, default=112078)
        parser.add_argument("--diff", type=float, default=0.2)
        parser.add_argument("--mixup", type=int, default=1, choices=[0, 1])
        parser.add_argument("--sample_selection", action="store_true")
        parser.add_argument("--enable_nni", action="store_true")
        parser.add_argument("--explain", action="store_true")
        parser.add_argument("--wandb", action="store_true", help="Track experiment")
        parser.add_argument("--log_result", action="store_true")
        parser.add_argument("--data_folder", type=str, default="datasets/")
        return parser

    def main(self):
        """Run the full experiment described by the command-line arguments.

        Loads the dataset, optionally builds or reloads the pairwise feature
        caches used for sample selection, trains one model per fold and per
        entry of ``--k_list``, evaluates the best saved checkpoint on the held
        out fold, and finally writes predictions and a summary log under
        ``./outputs``.
        """
        args = self._build_parser().parse_args()

        self_dir = os.path.dirname(os.path.realpath(__file__))
        root_dir = osp.join(self_dir, args.data_folder)
        dataset = BrainDataset(
            root=root_dir,
            name=args.dataset,
            pre_transform=get_transform(args.node_features),
            num_classes=args.num_classes,
        )
        args.num_nodes = dataset.num_nodes
        args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        init_wandb(
            name=f"{args.model_name}-{args.dataset}",
            heads=args.num_heads,
            epochs=args.epochs,
            hidden_channels=args.hidden_dim,
            node_features=args.node_features,
            lr=args.lr,
            weight_decay=args.weight_decay,
            num_classes=args.num_classes,
            device=args.device,
        )

        if args.enable_nni:
            args = ModifiedArgs(args, nni.get_next_parameter())

        # init model
        model_name = str(args.model_name).lower()
        args.model_name = model_name

        y = get_y(dataset)
        connectomes = get_x(dataset).T

        class_weights = class_weight.compute_class_weight(
            class_weight="balanced", classes=np.unique(y), y=y
        )
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(args.device)

        test_accs, test_aucs, preds_all, labels_all = [], [], [], []

        if args.sample_selection:
            # NOTE(review): the cache file names depend on node_features and
            # num_classes only, while the cached features are computed from
            # args.centrality_measure, so switching the measure reuses a stale
            # cache. The paths are also relative to the working directory, not
            # to root_dir.
            data_dict_path = (
                f"{args.data_folder}data_dict_{args.node_features}_{args.num_classes}.pkl"
            )
            score_dict_path = (
                f"{args.data_folder}score_dict_{args.node_features}_{args.num_classes}.pkl"
            )
            # Check if node centrality features and subject labels exist
            if os.path.exists(data_dict_path):
                with open(data_dict_path, "rb") as d_d:
                    data_dict = pickle.load(d_d)
                with open(score_dict_path, "rb") as s_d:
                    score_dict = pickle.load(s_d)
            else:  # Create node centrality features and subject labels
                data_dict, score_dict = create_features(
                    connectomes.numpy(), y, args, args.centrality_measure
                )
                with open(data_dict_path, "wb") as d_d:
                    pickle.dump(data_dict, d_d)
                with open(score_dict_path, "wb") as s_d:
                    pickle.dump(score_dict, s_d)

        splitter = KFold(
            args.k_fold_splits,
            shuffle=True,
            random_state=args.seed,
        )
        for fold, (train_idx, test_idx) in enumerate(splitter.split(dataset)):
            print(f"Cross Validation Fold {fold+1}/{args.k_fold_splits}")

            if args.sample_selection:
                # Select top-k subjects with highest predictive power for labels
                sample_atlas = select_samples_per_class(
                    train_idx,
                    args.n_select_splits,
                    args.k_list,
                    data_dict,
                    score_dict,
                    y,
                    shuffle=True,
                    rs=args.seed,
                )

            for k in args.k_list:
                if args.sample_selection:
                    selected_train_idxs = np.array(
                        [
                            sample_idx
                            for class_samples in sample_atlas.values()
                            for sample_indices in class_samples.values()
                            for sample_idx in sample_indices
                        ]
                    )
                else:
                    selected_train_idxs = np.array(train_idx)

                # Apply RandomOverSampler to balance classes
                train_res_idxs, _ = RandomOverSampler().fit_resample(
                    selected_train_idxs.reshape(-1, 1),
                    [y[i] for i in selected_train_idxs],
                )

                train_set = [dataset[i] for i in train_res_idxs.ravel()]
                test_set = [dataset[i] for i in test_idx]
                train_loader = DataLoader(
                    train_set, batch_size=args.train_batch_size, shuffle=True
                )
                test_loader = DataLoader(
                    test_set, batch_size=args.test_batch_size, shuffle=False
                )

                model = build_model(args, dataset.num_features)
                model = model.to(args.device)
                optimizer = torch.optim.AdamW(
                    model.parameters(), lr=args.lr, weight_decay=args.weight_decay
                )
                scheduler = ReduceLROnPlateau(
                    optimizer, mode="min", factor=0.5, patience=5, verbose=True
                )

                _, _, train_model = train_eval(
                    model,
                    optimizer,
                    scheduler,
                    class_weights,
                    args,
                    train_loader,
                    test_loader,
                )

                save_model(
                    args.epochs, train_model, optimizer, args
                )  # save trained model

                # test the best epoch saved model
                best_model_cp = torch.load(
                    f"model_checkpoints/best_model_{args.model_name}_{args.num_classes}.pth"
                )
                model.load_state_dict(best_model_cp["model_state_dict"])

                test_acc, test_auc, t_preds, t_labels = test(model, test_loader, args)

                logging.info(
                    f"(Performance Last Epoch) | test_acc={(test_acc * 100):.2f}, "
                    + f"test_auc={(test_auc * 100):.2f}"
                )
                test_accs.append(test_acc)
                test_aucs.append(test_auc)
                preds_all.append(t_preds)
                labels_all.append(t_labels)

                if args.explain:
                    explain(model, test_loader, args)

        # Store predictions and targets
        curr_dt = str(datetime.now())
        tag = "multi" if args.num_classes > 2 else "binary"
        saved_results = {
            "preds": np.array(preds_all, dtype="object"),
            "labels": np.array(labels_all, dtype="object"),
        }
        # Create the output folders first so a finished run cannot fail on a
        # missing directory.
        results_dir = "./outputs/results"
        logs_dir = "./outputs/logs"
        os.makedirs(results_dir, exist_ok=True)
        os.makedirs(logs_dir, exist_ok=True)
        np.savez(
            f"{results_dir}/{curr_dt}_{args.model_name}_{args.node_features}_{tag}",
            **saved_results,
        )

        result_str = (
            f"(K Fold Final Result)| avg_acc={(np.mean(test_accs) * 100):.2f} +- {(np.std(test_accs) * 100): .2f}, "
            f"avg_auc={(np.mean(test_aucs) * 100):.2f} +- {np.std(test_aucs) * 100:.2f}\n"
        )
        logging.info(result_str)

        with open(
            f"{logs_dir}/{curr_dt}_{args.model_name}_{args.node_features}_{tag}.log",
            "a",
        ) as f:
            # Write all input arguments to f
            input_arguments: List[str] = sys.argv
            f.write(f"{input_arguments}\n")
            f.write(result_str + "\n")
        if args.enable_nni:
            nni.report_final_result(np.mean(test_aucs))
def _node_values(per_node):
    """Return the per-node values of a networkx degree/centrality result as a 1-D array."""
    return np.array(list(dict(per_node).values()))


def create_features(data, score, args, method="eigen"):
    """Given data matrix and score vector, creates the dictionaries for
    pairwise similarity features.

    Possible values for method:
        'abs': absolute differences
        'geo': geometric distance
        'tan': vectorized tangent matrix
        'node': node degree centrality
        'eigen': eigenvector centrality
        'close': closeness centrality
        'concat_orig': degree, eigenvector and closeness centrality, concatenated
        'concat_scale': like 'concat_orig', each part min-max scaled over all pairs

    :param data: array indexed as ``data[:, :, subject]``, one matrix per subject
    :param score: per-subject scores, indexed by subject
    :param args: object providing ``num_nodes``
    :param method: name of the pairwise feature to compute (see above)
    :return: ``(data_dict, score_dict)``; both are keyed by ``(i, j)`` and
        ``(j, i)`` for every subject pair and hold the pair's feature and the
        absolute score difference respectively
    :raises ValueError: if ``method`` is not one of the supported names
    """
    tangent_methods = ("tan", "node", "eigen", "close", "concat_orig", "concat_scale")
    if method not in ("abs", "geo") + tangent_methods:
        # An unknown name used to surface as an UnboundLocalError inside the loop.
        raise ValueError(f"Unknown feature extraction method: {method!r}")

    data_dict = {}
    score_dict = {}
    spd = SPD.SPD(args.num_nodes)
    # Tiny diagonal shift added to every matrix before the matrix logarithm.
    ridge = np.eye(args.num_nodes) * 1e-10
    n_subjects = data.shape[2]

    print("Starting topological feature extraction...")
    for i in tqdm(range(n_subjects)):
        for j in range(i + 1, n_subjects):
            if method == "abs":
                dist = np.abs(data[:, :, i] - data[:, :, j])
                dist = dist[np.triu_indices_from(dist, k=1)]
            elif method == "geo":
                dist = spd.dist(data[:, :, i], data[:, :, j])
            else:
                dist = spd.transp(
                    data[:, :, 0] + ridge,
                    data[:, :, i] + ridge,
                    spd.log(data[:, :, i] + ridge, data[:, :, j] + ridge),
                )
                if method == "tan":
                    dist = dist[np.triu_indices_from(dist)]
                else:
                    # from_numpy_matrix was removed in NetworkX 3.0;
                    # from_numpy_array builds the same weighted graph.
                    graph = nx.from_numpy_array(np.abs(dist))
                    if method == "node":
                        dist = _node_values(graph.degree(weight="weight"))
                    elif method == "eigen":
                        dist = _node_values(
                            nx.eigenvector_centrality(graph, weight="weight")
                        )
                    elif method == "close":
                        dist = _node_values(
                            nx.closeness_centrality(graph, distance="weight")
                        )
                    else:  # "concat_orig" / "concat_scale": keep all three parts
                        dist = (
                            _node_values(graph.degree(weight="weight")),
                            _node_values(
                                nx.eigenvector_centrality(graph, weight="weight")
                            ),
                            _node_values(
                                nx.closeness_centrality(graph, distance="weight")
                            ),
                        )
            data_dict[(i, j)] = data_dict[(j, i)] = dist
            score_dict[(i, j)] = score_dict[(j, i)] = np.abs(score[i] - score[j])

    if method == "concat_orig":
        data_dict = {key: np.concatenate(parts) for key, parts in data_dict.items()}
    if method == "concat_scale":
        dicts = data_dict
        lists = np.array([[x[0], x[1], x[2]] for x in dicts.values()])
        list_a = lists[:, 0]
        list_b = lists[:, 1]
        list_c = lists[:, 2]
        max_a, min_a = np.max(list_a, axis=0), np.min(list_a, axis=0)
        max_b, min_b = np.max(list_b, axis=0), np.min(list_b, axis=0)
        max_c, min_c = np.max(list_c, axis=0), np.min(list_c, axis=0)
        diff_a = max_a - min_a
        diff_b = max_b - min_b
        diff_c = max_c - min_c
        # NOTE(review): an entry that is constant over all pairs has a zero
        # range here and therefore scales to NaN/inf.
        data_dict = {}
        for key, (a, b, c) in dicts.items():
            data_dict[key] = np.concatenate(
                ((a - min_a) / diff_a, (b - min_b) / diff_b, (c - min_c) / diff_c)
            )

    return data_dict, score_dict
def get_graph(mat):
    """
    Get edge indices and edge weights from a dense adjacency matrix.

    :param mat: square adjacency matrix as a torch tensor (``mat.numpy()`` is
        called on it); non-zero entries become weighted, undirected edges
    :return: edge_index: tensor with one ``(node1, node2)`` row per edge,
                 i.e. shape (num_edges, 2)
             edge_attribute: tensor of shape (num_edges,) with the edge weights

    NOTE(review): torch_geometric ``Data`` objects presumably expect
    ``edge_index`` as (2, num_edges); callers would have to transpose - confirm
    before wiring this into a dataset.
    """
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the supported way to build a graph from a 2-D array.
    graph = nx.from_numpy_array(mat.numpy(), create_using=nx.Graph)
    weighted_edges = list(graph.edges(data="weight"))
    edge_index = torch.tensor([(node1, node2) for node1, node2, _ in weighted_edges])
    edge_attribute = torch.tensor([weight for _, _, weight in weighted_edges])
    return edge_index, edge_attribute
def counts_from_cm(cm):
    """
    Returns TP, FN, FP, and TN summed over all classes of a confusion matrix.

    Each class is treated one-vs-rest, with rows as the class on one axis and
    columns as the other: the diagonal entry is its TP, the rest of its row is
    FN, the rest of its column is FP and every remaining cell is TN.

    :param cm: square confusion matrix (array-like)
    :return: tuple ``(tp_all, fn_all, fp_all, tn_all)`` of floats
    """
    cm = np.asarray(cm)
    total = cm.sum()
    tp_all, fn_all, fp_all, tn_all = 0.0, 0.0, 0.0, 0.0

    for i in range(cm.shape[0]):
        # The true positives of class i sit on the diagonal, not in row 0.
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        # Everything outside row i and column i.
        tn = total - tp - fn - fp

        tp_all += tp
        fn_all += fn
        fp_all += fp
        tn_all += tn
    return tp_all, fn_all, fp_all, tn_all
def generate_full_edges(num_nodes) -> Tensor:
    """Return the edge index of the fully connected graph, self-loops included.

    Column ``source * num_nodes + target`` holds the edge ``source -> target``.

    :param num_nodes: number of nodes in the graph
    :return: int64 tensor of shape (2, num_nodes * num_nodes); row 0 holds the
        source nodes and row 1 the target nodes
    """
    # np.long was removed in NumPy 1.24; int64 is the dtype LongTensor uses.
    nodes = np.arange(num_nodes, dtype=np.int64)
    sources = np.repeat(nodes, num_nodes)  # 0,0,...,1,1,...
    targets = np.tile(nodes, num_nodes)  # 0,1,...,0,1,...
    full_edge_index = np.stack((sources, targets))
    return torch.from_numpy(full_edge_index)
Right),atlas.MidFG l (Middle Frontal Gyrus Left),"atlas.IFG tri r (Inferior Frontal Gyrus, pars triangularis Right)","atlas.IFG tri l (Inferior Frontal Gyrus, pars triangularis Left)","atlas.IFG oper r (Inferior Frontal Gyrus, pars opercularis Right)","atlas.IFG oper l (Inferior Frontal Gyrus, pars opercularis Left)",atlas.PreCG r (Precentral Gyrus Right),atlas.PreCG l (Precentral Gyrus Left),atlas.TP r (Temporal Pole Right),atlas.TP l (Temporal Pole Left),"atlas.aSTG r (Superior Temporal Gyrus, anterior division Right)","atlas.aSTG l (Superior Temporal Gyrus, anterior division Left)","atlas.pSTG r (Superior Temporal Gyrus, posterior division Right)","atlas.pSTG l (Superior Temporal Gyrus, posterior division Left)","atlas.aMTG r (Middle Temporal Gyrus, anterior division Right)","atlas.aMTG l (Middle Temporal Gyrus, anterior division Left)","atlas.pMTG r (Middle Temporal Gyrus, posterior division Right)","atlas.pMTG l (Middle Temporal Gyrus, posterior division Left)","atlas.toMTG r (Middle Temporal Gyrus, temporooccipital part Right)","atlas.toMTG l (Middle Temporal Gyrus, temporooccipital part Left)","atlas.aITG r (Inferior Temporal Gyrus, anterior division Right)","atlas.aITG l (Inferior Temporal Gyrus, anterior division Left)","atlas.pITG r (Inferior Temporal Gyrus, posterior division Right)","atlas.pITG l (Inferior Temporal Gyrus, posterior division Left)","atlas.toITG r (Inferior Temporal Gyrus, temporooccipital part Right)","atlas.toITG l (Inferior Temporal Gyrus, temporooccipital part Left)",atlas.PostCG r (Postcentral Gyrus Right),atlas.PostCG l (Postcentral Gyrus Left),atlas.SPL r (Superior Parietal Lobule Right),atlas.SPL l (Superior Parietal Lobule Left),"atlas.aSMG r (Supramarginal Gyrus, anterior division Right)","atlas.aSMG l (Supramarginal Gyrus, anterior division Left)","atlas.pSMG r (Supramarginal Gyrus, posterior division Right)","atlas.pSMG l (Supramarginal Gyrus, posterior division Left)",atlas.AG r (Angular Gyrus Right),atlas.AG l (Angular Gyrus 
Left),"atlas.sLOC r (Lateral Occipital Cortex, superior division Right)","atlas.sLOC l (Lateral Occipital Cortex, superior division Left)","atlas.iLOC r (Lateral Occipital Cortex, inferior division Right)","atlas.iLOC l (Lateral Occipital Cortex, inferior division Left)",atlas.ICC r (Intracalcarine Cortex Right),atlas.ICC l (Intracalcarine Cortex Left),atlas.MedFC (Frontal Medial Cortex),atlas.SMA r (Juxtapositional Lobule Cortex -formerly Supplementary Motor Cortex- Right),atlas.SMA L(Juxtapositional Lobule Cortex -formerly Supplementary Motor Cortex- Left),atlas.SubCalC (Subcallosal Cortex),atlas.PaCiG r (Paracingulate Gyrus Right),atlas.PaCiG l (Paracingulate Gyrus Left),"atlas.AC (Cingulate Gyrus, anterior division)","atlas.PC (Cingulate Gyrus, posterior division)",atlas.Precuneous (Precuneous Cortex),atlas.Cuneal r (Cuneal Cortex Right),atlas.Cuneal l (Cuneal Cortex Left),atlas.FOrb r (Frontal Orbital Cortex Right),atlas.FOrb l (Frontal Orbital Cortex Left),"atlas.aPaHC r (Parahippocampal Gyrus, anterior division Right)","atlas.aPaHC l (Parahippocampal Gyrus, anterior division Left)","atlas.pPaHC r (Parahippocampal Gyrus, posterior division Right)","atlas.pPaHC l (Parahippocampal Gyrus, posterior division Left)",atlas.LG r (Lingual Gyrus Right),atlas.LG l (Lingual Gyrus Left),"atlas.aTFusC r (Temporal Fusiform Cortex, anterior division Right)","atlas.aTFusC l (Temporal Fusiform Cortex, anterior division Left)","atlas.pTFusC r (Temporal Fusiform Cortex, posterior division Right)","atlas.pTFusC l (Temporal Fusiform Cortex, posterior division Left)",atlas.TOFusC r (Temporal Occipital Fusiform Cortex Right),atlas.TOFusC l (Temporal Occipital Fusiform Cortex Left),atlas.OFusG r (Occipital Fusiform Gyrus Right),atlas.OFusG l (Occipital Fusiform Gyrus Left),atlas.FO r (Frontal Operculum Cortex Right),atlas.FO l (Frontal Operculum Cortex Left),atlas.CO r (Central Opercular Cortex Right),atlas.CO l (Central Opercular Cortex Left),atlas.PO r (Parietal Operculum Cortex 
Right),atlas.PO l (Parietal Operculum Cortex Left),atlas.PP r (Planum Polare Right),atlas.PP l (Planum Polare Left),atlas.HG r (Heschl's Gyrus Right),atlas.HG l (Heschl's Gyrus Left),atlas.PT r (Planum Temporale Right),atlas.PT l (Planum Temporale Left),atlas.SCC r (Supracalcarine Cortex Right),atlas.SCC l (Supracalcarine Cortex Left),atlas.OP r (Occipital Pole Right),atlas.OP l (Occipital Pole Left),atlas.Thalamus r,atlas.Thalamus l,atlas.Caudate r,atlas.Caudate l,atlas.Putamen r,atlas.Putamen l,atlas.Pallidum r,atlas.Pallidum l,atlas.Hippocampus r,atlas.Hippocampus l,atlas.Amygdala r,atlas.Amygdala l,atlas.Accumbens r,atlas.Accumbens l,atlas.Brain-Stem,atlas.Cereb1 l (Cerebelum Crus1 Left),atlas.Cereb1 r (Cerebelum Crus1 Right),atlas.Cereb2 l (Cerebelum Crus2 Left),atlas.Cereb2 r (Cerebelum Crus2 Right),atlas.Cereb3 l (Cerebelum 3 Left),atlas.Cereb3 r (Cerebelum 3 Right),atlas.Cereb45 l (Cerebelum 4 5 Left),atlas.Cereb45 r (Cerebelum 4 5 Right),atlas.Cereb6 l (Cerebelum 6 Left),atlas.Cereb6 r (Cerebelum 6 Right),atlas.Cereb7 l (Cerebelum 7b Left),atlas.Cereb7 r (Cerebelum 7b Right),atlas.Cereb8 l (Cerebelum 8 Left),atlas.Cereb8 r (Cerebelum 8 Right),atlas.Cereb9 l (Cerebelum 9 Left),atlas.Cereb9 r (Cerebelum 9 Right),atlas.Cereb10 l (Cerebelum 10 Left),atlas.Cereb10 r (Cerebelum 10 Right),atlas.Ver12 (Vermis 1 2),atlas.Ver3 (Vermis 3),atlas.Ver45 (Vermis 4 5),atlas.Ver6 (Vermis 6),atlas.Ver7 (Vermis 7),atlas.Ver8 (Vermis 8),atlas.Ver9 (Vermis 9),atlas.Ver10 (Vermis 10),"networks.DefaultMode.MPFC (1,55,-3)","networks.DefaultMode.LP (L) (-39,-77,33)","networks.DefaultMode.LP (R) (47,-67,29)","networks.DefaultMode.PCC (1,-61,38)","networks.SensoriMotor.Lateral (L) (-55,-12,29)","networks.SensoriMotor.Lateral (R) (56,-10,29)","networks.SensoriMotor.Superior (0,-31,67)","networks.Visual.Medial (2,-79,12)","networks.Visual.Occipital (0,-93,-4)","networks.Visual.Lateral (L) (-37,-79,10)","networks.Visual.Lateral (R) (38,-72,13)","networks.Salience.ACC 
(0,22,35)","networks.Salience.AInsula (L) (-44,13,1)","networks.Salience.AInsula (R) (47,14,0)","networks.Salience.RPFC (L) (-32,45,27)","networks.Salience.RPFC (R) (32,46,27)","networks.Salience.SMG (L) (-60,-39,31)","networks.Salience.SMG (R) (62,-35,32)","networks.DorsalAttention.FEF (L) (-27,-9,64)","networks.DorsalAttention.FEF (R) (30,-6,64)","networks.DorsalAttention.IPS (L) (-39,-43,52)","networks.DorsalAttention.IPS (R) (39,-42,54)","networks.FrontoParietal.LPFC (L) (-43,33,28)","networks.FrontoParietal.PPC (L) (-46,-58,49)","networks.FrontoParietal.LPFC (R) (41,38,30)","networks.FrontoParietal.PPC (R) (52,-52,45)","networks.Language.IFG (L) (-51,26,2)","networks.Language.IFG (R) (54,28,1)","networks.Language.pSTG (L) (-57,-47,15)","networks.Language.pSTG (R) (59,-42,13)","networks.Cerebellar.Anterior (0,-63,-30)","networks.Cerebellar.Posterior (0,-79,-32)" 2 | x,0.8875634373039859,26.155852181436163,-24.723501642404553,37.384686603092305,-36.394252441773105,14.665833101787534,-14.06542177548633,39.11812950292265,-38.07063593683312,51.86739433348816,-49.7100250144314,52.21275436046512,-50.642787682333875,34.50456521104985,-33.72482422317986,40.63669918270498,-40.49289576927155,57.50133451957296,-56.17233661593554,61.34069303894511,-62.288503253796094,57.88912196583758,-57.46777316735823,61.07506628874463,-60.906415646134164,58.181207326324916,-57.639859113589665,46.228870605833954,-48.142014016967906,53.42195697796432,-53.443172023149856,54.141678565762334,-51.81781305114639,37.62464293985949,-38.41085324232082,29.210169491525424,-29.301115879828327,58.41554321966693,-56.80042519266543,55.20992099209921,-54.8832702671712,51.93327067669173,-50.352640979841794,32.969060371275546,-31.960864750591433,45.53747714808044,-45.12883995376848,11.664292812189467,-10.19761630142253,0.20741405849153413,5.9213389121338915,-5.37041788856305,-0.07427114510586409,6.554901098901099,-6.207925322595406,0.8027729802341201,0.7848450176825463,0.9578895661687543,8.849179716629381,
-8.218706047819971,29.1139512866856,-29.542842369626236,22.358580106302202,-21.867960750853243,22.901229823212912,-21.895994832041342,13.56545584045584,-12.271574014221073,31.056416881998278,-31.882512601783638,36.27906976744186,-35.96667145525068,35.051491746175984,-33.49844840961986,27.258792209611883,-26.57573300905283,41.1186848436247,-39.703440936502304,49.4333100069979,-47.990138319672134,48.9025641025641,-48.356259833670485,48.001000667111406,-46.61085972850679,46.11156095366621,-45.196649381731156,54.96693046919163,-52.6993769470405,8.215702479338843,-8.36255572065379,17.72875074125321,-16.85335589396503,10.843621801133034,-9.993136072720526,13.3015606242497,-12.785498489425981,25.495748956562064,-24.90149125322627,19.850489054494645,-18.95779356456331,26.497066666666665,-25.17773788150808,23.0874861572536,-22.995011511895626,9.368263473053892,-9.463882618510159,0.48253182887829604,-36.469651107216826,37.64117073170732,-28.635843848580443,32.057459795750674,-8.800373134328359,11.98092105263158,-13.941201533873029,16.179533213644525,-22.829401977096996,24.461944173969492,-32.357835740461304,33.13924349881797,-25.75142478462558,25.06399563913873,-10.947429231658,9.460229031259672,-22.6142001710864,25.99453125,0.7574257425742574,1.380351262349067,1.2138716769938038,1.1722452049642722,1.145780051150895,1.152061855670103,0.8646671543525969,0.35583524027459956,1.112927191679049,-38.91066282420749,47.18401206636501,0.6302503620939375,-55.46723044397463,56.38646344613795,-0.07324259856882279,1.798646874684439,-0.21218590901625883,-36.83118556701031,38.394253414978806,0.38758231420507994,-43.68609865470852,46.54639175257732,-31.653516295025728,31.666092943201377,-59.57081545064378,61.859154929577464,-26.886363636363637,30.22222222222222,-38.62770167427702,38.95122728721709,-43.11685261303582,-46.23317307692308,40.767918088737204,52.35125448028674,-51.03319502074689,53.61638733705773,-57.025187803800264,58.55245998814463,0.23564695801199656,0.4031065088757396 3 | 
y,-24.19802702856817,52.14318378445186,52.95596919818348,2.5496713267290065,1.1867956423741548,18.425349634157637,18.68008318530393,18.623997462503965,18.439265898420828,27.76335346028797,28.497979603617473,15.41751453488372,14.517017828200972,-10.795921939266648,-11.826519618961216,12.956867914579488,11.102216095854098,-0.7633451957295374,-3.906445837063563,-23.98589389757743,-29.169197396963124,-1.5220257716511838,-4.205255878284924,-22.524549693700283,-27.358800950466094,-49.22239080958058,-52.99985324332257,-2.4098728496634254,-4.974548137218738,-23.46366736621196,-28.45720970323852,-49.87783595113438,-53.44356261022928,-26.374932448081527,-27.85986348122867,-47.777033898305085,-49.46892703862661,-27.06344171292625,-32.75139516343343,-40.365836583658364,-46.02946909758055,-51.80254613807245,-55.704899208981885,-71.11907857105714,-72.89364272411537,-73.94515539305301,-75.5447411643044,-73.57568731368002,-75.02556708958093,43.188045151359674,-2.7847698744769875,-2.778958944281525,20.535639064405277,36.56764835164835,36.65342729019859,18.293705622721166,-36.6217495319326,-59.287918369169624,-78.55219985085756,-80.28762306610408,23.070660125656254,23.662283941498007,-8.051442672741079,-9.108788395904437,-30.534204458109148,-32.42603359173127,-63.49145299145299,-65.67000646412411,-2.8070628768303187,-4.433889104303994,-24.144430844553245,-29.531245510702487,-50.06088141753748,-53.680954228083785,-75.39862687403671,-76.57343602215917,18.626704089815558,18.32600212841433,-5.77074877536739,-8.625128073770492,-27.641025641025642,-31.854124522364575,-3.584723148765844,-5.971806474068917,-17.404408457040034,-20.321499800558435,-25.071509327303563,-29.706052514463728,-74.4892561983471,-73.25111441307578,-95.13461158331685,-96.5014100394811,-18.321254151201405,-19.2239124385493,10.010804321728692,8.976992795723914,1.7760086566702735,0.48264984227129337,-4.005123428039124,-5.120351023819474,-20.958933333333334,-23.191610902562427,-3.985234403839055,-4.946661550268611,12.20359
2814371257,11.496613995485328,-29.726241477101308,-66.03260580956933,-66.95131707317073,-73.26031808622503,-69.01590562272567,-37.22294776119403,-34.90855263157895,-44.1981252662974,-44.270895671254735,-58.04345698345894,-57.53716325868225,-59.820219874973056,-63.17801418439716,-54.51908548707753,-56.33982011447261,-48.95046216060081,-49.49953574744661,-33.79982891360137,-33.83828125,-38.79207920792079,-39.93084522502744,-52.17889266440136,-66.45204964272283,-71.93030690537084,-64.42938144329896,-54.87490855888807,-45.79977116704806,55.326894502228825,-76.87223823246877,-66.60482654600301,-60.982826401820816,-12.36416490486258,-9.868251097907518,-30.53570927458959,-78.76158739775825,-92.77483987518475,-78.50837628865979,-71.98304286387187,22.118532455315147,12.798206278026905,14.237113402061855,45.054888507718694,46.406196213425126,-38.68669527896996,-35.03521126760563,-8.590909090909092,-6.37037037037037,-43.21400304414003,-41.76984379980873,33.18614210217264,-58.12019230769231,37.924914675767916,-52.0310633213859,25.813278008298756,28.37243947858473,-47.44851966416262,-42.02133965619443,-63.497000856898026,-78.94045857988165 4 | 
z,5.580566801619433,8.257817327895193,7.5081222738776905,-0.1737802055365244,0.08236288504883546,56.958368065203295,56.161994714267145,42.78576283474557,42.06419120785318,7.707385044124478,8.669424668077736,16.202398255813954,15.394327390599676,50.13351419153467,49.369896801995914,-29.621144213023992,-29.60465486162655,-10.165925266903914,-7.970008952551477,1.57344372891751,3.7964053300278895,-24.505843572070724,-22.13914246196404,-12.149858279235621,-10.996709925059404,1.5974856399696542,0.8241855004402701,-41.1058339566193,-39.19107340464773,-28.13352570828961,-25.987563107991626,-16.730763128668887,-16.5331569664903,52.63761290820659,51.66716723549488,58.90220338983051,57.468583690987124,37.812212529738304,37.19465851714058,33.6046604660466,33.241944731108816,32.35586124401914,29.762822148507272,38.924600904789145,37.97183772084361,-1.5775137111517368,-1.90029807165886,8.316826763829082,8.041906958861976,-18.512698819907644,57.54075313807532,56.07734604105572,-14.832612792373352,22.69432967032967,20.789237668161434,24.34508731529457,29.975088412731434,38.026994440351,27.88683818046234,27.14299109235818,-16.23143127635769,-16.57261042990102,-30.25607441154138,-30.301407849829353,-16.75903151421983,-16.885658914728683,-4.957763532763533,-5.438510019392372,-42.34065460809647,-41.90306320279178,-27.82951564958909,-25.079873581381985,-16.63849765258216,-15.974398758727697,-12.307692307692308,-13.588298878529928,4.913793103448276,4.526427811280596,11.131980405878236,11.813140368852459,21.54731934731935,20.46122724207687,-7.186457638425617,-7.337626174730247,6.967611336032388,7.192261667331472,12.066421707179197,10.789052069425901,14.081818181818182,14.796433878157504,8.310387428345523,6.737638653882309,6.624340691541317,6.294870605695205,10.490516206482592,9.73739251684871,0.3034472097696707,0.3395468884427875,-1.1891010712622263,-1.3338905139991641,-14.250133333333334,-13.80594091725151,-17.686969361387966,-17.731772831926325,-6.5344311377245505,-7.170428893905192,-34
.97725215712303,-29.777069725452712,-29.80531707317073,-38.203732912723446,-39.94870289940134,-18.58115671641791,-19.26513157894737,-17.14131515409743,-18.573907839616997,-23.68111970245669,-24.512414800389482,-45.449450312567365,-48.45721040189125,-47.684625579854206,-49.46802943581357,-45.90294627383015,-46.326833797585884,-41.76475620188195,-41.346875,-20.04950495049505,-11.397914379802415,-6.6887867279632225,-15.84956750658142,-25.14066496163683,-34.079896907216494,-34.89612289685443,-31.683066361556065,-2.885586924219911,33.13544668587896,28.726998491704375,38.13697496379061,29.489957716701902,28.818393180056834,67.4050792759927,11.945269110370594,-3.8764986040400724,9.794458762886597,12.829486575600566,34.809031044214485,0.6816143497757847,0.030927835051546393,27.253859348198972,27.38382099827883,30.523605150214593,31.514084507042252,64.43181818181819,63.96296296296296,52.25814307458143,54.223780682180426,28.244274809160306,48.50240384615385,30.41638225255973,45.340501792114694,1.7385892116182573,1.1545623836126628,15.004860804242156,12.848844101956136,-30.303341902313626,-31.589497041420117 5 | --------------------------------------------------------------------------------