├── .gitignore
├── LICENSE
├── README.md
├── graph_constuction.py
├── layers.py
├── model.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Repo-specific
results/
data/
models/__pycache__/
.vscode/
logs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# The JetBrains-specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 bintsi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multimodal brain age estimation using interpretable adaptive population-graph learning

## About
This is a PyTorch Lightning implementation of the paper
[Multimodal brain age estimation using interpretable adaptive population-graph learning](https://arxiv.org/abs/2307.04639)
(MICCAI 2023) by Kyriaki-Margarita Bintsi, Vasileios Baltatzis, Rolandos Alexandros Potamias, Alexander Hammers, and Daniel Rueckert.

## Requirements

    conda install -c anaconda cmake=3.19
    conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=10.1 -c pytorch
    pip install pytorch_lightning==1.3.8
    pip install torch-geometric

## Dataset
The dataset used in this paper is the UK Biobank. Since the data is not public, we cannot share the CSV files.
You need to place your own files, named train.csv, val.csv, and test.csv, in the `data/` folder.

Every CSV needs to have the following format:
- Column 0: eid
- Column 1: label (age)
- Columns 2-21: non-imaging phenotypes
- Columns 22-89: imaging phenotypes
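
As a quick sanity check, a sketch like the following can verify that your CSVs match the layout the code expects (the label column must literally be named `Age`, since the loader indexes `data_df['Age']`; the total of 90 columns follows the column indices hard-coded in `train.py`):

```python
import pandas as pd

# Minimal sketch: verify a CSV matches the expected layout.
df = pd.read_csv("data/train.csv")
assert df.columns[1] == "Age", "column 1 must be the age label, named 'Age'"
assert df.shape[1] == 90, f"expected 90 columns, got {df.shape[1]}"

non_imaging = df.iloc[:, 2:22]  # 20 non-imaging phenotypes
imaging = df.iloc[:, 22:90]     # 68 imaging phenotypes
print(non_imaging.shape, imaging.shape)
```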

## Training
To train a model with the hyperparameters chosen for the regression task, run the following command:
`python train.py`

## Reference
If you find the code useful, please cite:
```
@article{bintsi2023multimodal,
  title={Multimodal brain age estimation using interpretable adaptive population-graph learning},
  author={Bintsi, Kyriaki-Margarita and Baltatzis, Vasileios and Potamias, Rolandos Alexandros and Hammers, Alexander and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2307.04639},
  year={2023}
}
```
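
Note that all hyperparameters (graph convolution type, number of sampled neighbours `k`, distance metric, and so on) live in the `run_params` dictionary at the top of `train.py`. A sketch of the kind of edits you might make there (values purely illustrative):

```python
# In train.py, edit run_params before launching (illustrative values):
run_params["k"] = 10               # neighbours sampled per node
run_params["gfun"] = "sage"        # 'edgeconv', 'gcn', 'gat', 'sage', or 'chebconv'
run_params["distance"] = "cosine"  # 'euclidean', 'hyperbolic', or 'cosine'
run_params["lr"] = 5e-4
```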
68 | """ 69 | phenotypes_df = data_df[self.phenotype_columns] 70 | 71 | return phenotypes_df 72 | 73 | def get_features_demographics(self, phenotypes_df): 74 | """ 75 | Returns the phenotypes of every node, meaning for every subject. 76 | The node features are defined by the non-imaging information 77 | """ 78 | phenotypes = phenotypes_df.to_numpy() 79 | phenotypes = torch.from_numpy(phenotypes).float() 80 | return phenotypes 81 | 82 | def get_node_features(self, data_df): 83 | """ 84 | Returns the features of every node, meaning for every subject. 85 | """ 86 | df_node_features = data_df.iloc[:, 2:] 87 | node_features = df_node_features.to_numpy() 88 | node_features = torch.from_numpy(node_features).float() 89 | return node_features 90 | 91 | def get_subject_masks(self, train_index, validate_index, test_index): 92 | """Returns the boolean masks for the arrays of integer indices. 93 | 94 | inputs: 95 | train_index: indices of subjects in the train set. 96 | validate_index: indices of subjects in the validation set. 97 | test_index: indices of subjects in the test set. 98 | 99 | returns: 100 | a tuple of boolean masks corresponding to the train/validate/test set indices. 101 | """ 102 | 103 | num_subjects = len(train_index) + len(validate_index) + len(test_index) 104 | 105 | train_mask = np.zeros(num_subjects, dtype=bool) 106 | train_mask[train_index] = True 107 | train_mask = torch.from_numpy(train_mask) 108 | 109 | validate_mask = np.zeros(num_subjects, dtype=bool) 110 | validate_mask[validate_index] = True 111 | validate_mask = torch.from_numpy(validate_mask) 112 | 113 | test_mask = np.zeros(num_subjects, dtype=bool) 114 | test_mask[test_index] = True 115 | test_mask = torch.from_numpy(test_mask) 116 | 117 | return train_mask, validate_mask, test_mask 118 | 119 | def get_labels(self, data_df): 120 | """ 121 | Returns the labels for every node, in our case, age. 122 | 123 | """ 124 | if self.task == 'regression': 125 | labels = data_df['Age'].values 126 | labels = torch.from_numpy(labels).float() 127 | elif self.task == 'classification': 128 | labels = data_df['Age'].values 129 | print(np.unique(labels, return_counts=True)) 130 | labels = torch.from_numpy(labels) 131 | else: 132 | raise ValueError('Task should be either regression or classification.') 133 | return labels 134 | 135 | def get_edges_using_KNNgraph(self, dataset, k): 136 | """ 137 | Extracts edge index based on the cosine similarity of the node features. 138 | 139 | Inputs: 140 | dataset: the population graph (without edge_index). 141 | k: number of edges that will be kept for every node. 142 | 143 | Returns: 144 | dataset: graph dataset with the acquired edges. 145 | """ 146 | 147 | if self.edges == 'phenotypes': 148 | # Edges extracted based on the similarity of the selected phenotypes (imaging+non imaging) 149 | dataset.pos = dataset.phenotypes 150 | elif self.edges == 'imaging': 151 | # Edges extracted based on the similarity of the node features 152 | dataset.pos = dataset.x 153 | else: 154 | raise ValueError('Choose appropriate edge connection.') 155 | 156 | dataset.cuda() 157 | dataset = KNNGraph(k=k, force_undirected=True)(dataset) 158 | dataset.to('cpu') 159 | dataset = Data(x = dataset.x, y = dataset.y, phenotypes = dataset.phenotypes, train_mask=dataset.train_mask, 160 | val_mask= dataset.val_mask, test_mask=dataset.test_mask, edge_index=dataset.edge_index, 161 | num_nodes=dataset.num_nodes) 162 | return dataset 163 | 164 | def get_population_graph(self): 165 | """ 166 | Creates the population graph. 
167 | """ 168 | # Load data 169 | data_df, train_idx, val_idx, test_idx, num_nodes = self.load_data() 170 | 171 | # Take phenotypes and node_features dataframes 172 | phenotypes_df = self.get_phenotypes(data_df) 173 | phenotypes = self.get_features_demographics(phenotypes_df) 174 | node_features = self.get_node_features(data_df) 175 | 176 | # Mask val & test subjects 177 | train_mask, val_mask, test_mask = self.get_subject_masks(train_idx, val_idx, test_idx) 178 | # Get the labels 179 | labels = self.get_labels(data_df) 180 | 181 | if self.task == 'classification': 182 | labels= one_hot_embedding(labels,abs(self.num_classes)) 183 | 184 | population_graph = Data(x = node_features, y= labels, phenotypes= phenotypes, train_mask= train_mask, val_mask=val_mask, test_mask=test_mask, num_nodes=num_nodes, k=self.k) 185 | # Get edges using existing pyg KNNGraph class 186 | population_graph = self.get_edges_using_KNNgraph(population_graph, k=self.k) 187 | return population_graph 188 | 189 | def one_hot_embedding(labels, num_classes): 190 | y = torch.eye(num_classes) 191 | return y[labels] 192 | 193 | class UKBBageDataset(torch.utils.data.Dataset): 194 | def __init__(self, graph, split='train', samples_per_epoch=100, device='cpu', num_classes=2) -> None: 195 | dataset = graph 196 | self.n_features = dataset.num_node_features 197 | self.num_classes = abs(num_classes) 198 | self.X = dataset.x.float().to(device) 199 | self.y = dataset.y.float().to(device) 200 | self.y = dataset.y.float().to(device) 201 | self.phenotypes = dataset.phenotypes.float().to(device) 202 | self.edge_index = dataset.edge_index.to(device) 203 | 204 | if split=='train': 205 | self.mask = dataset.train_mask.to(device) 206 | if split=='val': 207 | self.mask = dataset.val_mask.to(device) 208 | if split=='test': 209 | self.mask = dataset.test_mask.to(device) 210 | 211 | self.samples_per_epoch = samples_per_epoch 212 | def __len__(self): 213 | return self.samples_per_epoch 214 | 215 | def __getitem__(self, idx): 216 | return self.X,self.y,self.mask,self.phenotypes, self.edge_index -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | ######################################################### 5 | # This section of code has been adapted from lcosmo/DGM_pytorch# 6 | # Modified by Margarita Bintsi# 7 | ######################################################### 8 | 9 | #Euclidean distance 10 | def pairwise_euclidean_distances(x, dim=-1): 11 | dist = torch.cdist(x,x)**2 12 | return dist, x 13 | 14 | # #Poincarè disk distance r=1 (Hyperbolic) 15 | def pairwise_poincare_distances(x, dim=-1): 16 | x_norm = (x**2).sum(dim,keepdim=True) 17 | x_norm = (x_norm.sqrt()-1).relu() + 1 18 | x = x/(x_norm*(1+1e-2)) 19 | x_norm = (x**2).sum(dim,keepdim=True) 20 | 21 | pq = torch.cdist(x,x)**2 22 | dist = torch.arccosh(1e-6+1+2*pq/((1-x_norm)*(1-x_norm.transpose(-1,-2))))**2 23 | return dist, x 24 | 25 | #Cosine similarity 26 | def pairwise_cosine_distances(x, dim=-1): 27 | dist = 1 - torch.mm(torch.nn.functional.normalize(x[0], p=2, dim=-1), torch.nn.functional.normalize(x[0], p=2, dim=-1).T).unsqueeze(0) 28 | return dist, x 29 | 30 | class MLP(nn.Module): 31 | def __init__(self, layers_size,final_activation=False, dropout=0): 32 | super(MLP, self).__init__() 33 | layers = [] 34 | for li in range(1,len(layers_size)): 35 | if dropout>0: 36 | layers.append(nn.Dropout(dropout)) 37 | 

--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

#########################################################
# This section of code has been adapted from lcosmo/DGM_pytorch #
# Modified by Margarita Bintsi #
#########################################################

# Euclidean distance
def pairwise_euclidean_distances(x, dim=-1):
    dist = torch.cdist(x, x)**2
    return dist, x

# Poincaré disk distance, r=1 (hyperbolic)
def pairwise_poincare_distances(x, dim=-1):
    x_norm = (x**2).sum(dim, keepdim=True)
    x_norm = (x_norm.sqrt() - 1).relu() + 1
    x = x / (x_norm * (1 + 1e-2))
    x_norm = (x**2).sum(dim, keepdim=True)

    pq = torch.cdist(x, x)**2
    dist = torch.arccosh(1e-6 + 1 + 2*pq / ((1 - x_norm) * (1 - x_norm.transpose(-1, -2))))**2
    return dist, x

# Cosine distance
def pairwise_cosine_distances(x, dim=-1):
    dist = 1 - torch.mm(torch.nn.functional.normalize(x[0], p=2, dim=-1),
                        torch.nn.functional.normalize(x[0], p=2, dim=-1).T).unsqueeze(0)
    return dist, x

class MLP(nn.Module):
    def __init__(self, layers_size, final_activation=False, dropout=0):
        super(MLP, self).__init__()
        layers = []
        for li in range(1, len(layers_size)):
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(layers_size[li-1], layers_size[li]))
            if li == len(layers_size)-1 and not final_activation:
                continue
            layers.append(nn.LeakyReLU(0.1))
        self.MLP = nn.Sequential(*layers)

    def forward(self, x, e=None):
        x = self.MLP(x)
        return x

class Attention(nn.Module):
    def __init__(self, layers_size, final_activation=False, dropout=0):
        super(Attention, self).__init__()
        layers = []
        for li in range(1, len(layers_size)):
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(layers_size[li-1], layers_size[li]))

            if li == len(layers_size)-1 and not final_activation:
                continue
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, e=None):
        x = self.mlp(x)
        # Return one attention weight per phenotype, min-max normalised to [0, 1]
        x = torch.sum(x, 1)
        x = (x - torch.min(x, dim=1)[0]) / (torch.max(x, dim=1)[0] - torch.min(x, dim=1)[0])
        return x

class Identity(nn.Module):
    def __init__(self, retparam=None):
        self.retparam = retparam
        super(Identity, self).__init__()

    def forward(self, *params):
        if self.retparam is not None:
            return params[self.retparam]
        return params

class GraphLearning(nn.Module):
    def __init__(self, embed_f, k=5, distance=pairwise_euclidean_distances, sparse=True):
        super(GraphLearning, self).__init__()

        self.sparse = sparse

        self.temperature = nn.Parameter(torch.tensor(1. if distance == "hyperbolic" else 4.).float())

        self.embed_f = embed_f
        self.centroid = None
        self.scale = None
        self.k = k

        self.debug = False
        if distance == 'euclidean':
            self.distance = pairwise_euclidean_distances
        elif distance == 'hyperbolic':
            self.distance = pairwise_poincare_distances
        elif distance == 'cosine':
            self.distance = pairwise_cosine_distances
        else:
            raise ValueError('Unsupported distance; choose euclidean, hyperbolic, or cosine.')

    def forward(self, x, A, phenotypes, not_used1=None, not_used2=None, fixedges=None):
        # Estimate attention coefficients for every phenotype (weights)
        att_weights = self.embed_f(phenotypes)
        # Weight the phenotypes by their attention scores
        phenotypes_weighted = att_weights * phenotypes

        if self.training:
            D, _x = self.distance(phenotypes_weighted)
            # sampling here
            edges_hat, logprobs = self.sample_without_replacement(D)
        else:
            with torch.no_grad():
                D, _x = self.distance(phenotypes_weighted)
                # sampling here
                edges_hat, logprobs = self.sample_without_replacement(D)
        return x, edges_hat, phenotypes, logprobs, att_weights

    def sample_without_replacement(self, logits):
        # Gumbel-top-k trick: perturb the temperature-scaled distances with Gumbel
        # noise and keep the top k entries per node, which samples k neighbours
        # without replacement while exposing log-probabilities for the graph loss.
        b, n, _ = logits.shape
        logits = logits * torch.exp(torch.clamp(self.temperature, -5, 5))

        q = torch.rand_like(logits) + 1e-8
        lq = (logits - torch.log(-torch.log(q)))
        logprobs, indices = torch.topk(-lq, self.k)
        rows = torch.arange(n).view(1, n, 1).to(logits.device).repeat(b, 1, self.k)
        edges = torch.stack((indices.view(b, -1), rows.view(b, -1)), -2)

        if self.sparse:
            return (edges + (torch.arange(b).to(logits.device)*n)[:, None, None]).transpose(0, 1).reshape(2, -1), logprobs
        return edges, logprobs
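
# --- Editor's sketch (not part of the original file): GraphLearning on toy data. ---
# Shapes are illustrative: a batch of 1 graph, 10 subjects, 35 phenotypes, 68 node features.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(1, 10, 68)           # node features
    phenotypes = torch.rand(1, 10, 35)   # phenotype vector per subject
    # Attention([35, 35]) scores every phenotype; GraphLearning samples k edges per node.
    gl = GraphLearning(Attention([35, 35]), k=3, distance='euclidean')
    x_out, edges, phen, logprobs, att = gl(x, None, phenotypes)
    print(edges.shape)  # sparse edge_index of shape [2, 1*10*3]
    print(att.shape)    # one attention weight per phenotype: [1, 35]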

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#########################################################
# This section of code has been adapted from lcosmo/DGM_pytorch #
# Modified by Margarita Bintsi #
#########################################################

import torch
import numpy as np

from torch.nn import ModuleList
from torch_geometric.nn import EdgeConv, ChebConv, GCNConv, GATConv, SAGEConv

import pytorch_lightning as pl
from argparse import Namespace
import torchmetrics

from layers import *

class GraphLearningModel(pl.LightningModule):
    def __init__(self, hparams, config=None):
        super(GraphLearningModel, self).__init__()

        if type(hparams) is not Namespace:
            hparams = Namespace(**hparams)

        self.save_hyperparameters(hparams)
        conv_layers = hparams.conv_layers
        self.fc_layers = hparams.fc_layers
        self.dgm_layers = hparams.dgm_layers
        self.test_eval = hparams.test_eval
        self.dropout = hparams.dropout
        self.graph_loss_mae = hparams.graph_loss_mae
        self.k = hparams.k
        self.lr = hparams.lr
        self.task = hparams.task
        self.num_classes = hparams.num_classes
        self.phenotype_columns = hparams.phenotype_columns

        # The criterion is stored as a string and matched against it in the step functions
        if self.task == "regression":
            self.criterion = 'torch.nn.HuberLoss()'
            # Metrics for regression
            self.mean_absolute_error = torchmetrics.MeanAbsoluteError()
            self.rscore = torchmetrics.PearsonCorrCoef()
        elif self.task == "classification":
            self.criterion = 'torch.nn.functional.binary_cross_entropy_with_logits()'
            # Metrics for classification
            self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
            self.auc = torchmetrics.AUROC(task="multiclass", num_classes=self.num_classes)
            self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=self.num_classes, average='macro')
        else:
            raise ValueError('Task should be either regression or classification.')

        # Here we build the models
        self.graph_f = ModuleList()
        self.node_g = ModuleList()
        for i, (dgm_l, conv_l) in enumerate(zip(self.dgm_layers, conv_layers)):
            if len(dgm_l) > 0:
                if 'ffun' not in hparams or hparams.ffun == 'phenotypes':
                    self.graph_f.append(GraphLearning(Attention(dgm_l), k=self.k, distance=hparams.distance))
            else:
                self.graph_f.append(Identity())

            if hparams.gfun == 'edgeconv':
                conv_l = conv_l.copy()
                conv_l[0] = conv_l[0]*2
                self.node_g.append(EdgeConv(MLP(conv_l), hparams.pooling))
            elif hparams.gfun == 'gcn':
                self.node_g.append(GCNConv(conv_l[0], conv_l[1]))
            elif hparams.gfun == 'gat':
                self.node_g.append(GATConv(conv_l[0], conv_l[1]))
            elif hparams.gfun == 'sage':
                self.node_g.append(SAGEConv(conv_l[0], conv_l[1]))
            elif hparams.gfun == 'chebconv':
                self.node_g.append(ChebConv(conv_l[0], conv_l[1], 2))
            else:
                raise Exception("Function %s not supported" % hparams.gfun)

        if self.fc_layers is not None and len(self.fc_layers) > 0:
            self.fc = MLP(self.fc_layers, final_activation=False)
        if hparams.pre_fc is not None and len(hparams.pre_fc) > 0:
            self.pre_fc = MLP(hparams.pre_fc, final_activation=True)
        self.avg_accuracy = None
        self.avg_mae = None

        # torch lightning specific
        self.automatic_optimization = False
        self.debug = False

    def forward(self, x, edges=None, phenotypes=None):
        if self.hparams.pre_fc is not None and len(self.hparams.pre_fc) > 0:
            x = self.pre_fc(x)

        graph_x = x.detach()
        lprobslist = []
        att_weights_list = []
        for f, g in zip(self.graph_f, self.node_g):
            graph_x, edges, phenotypes, lprobs, att_weights = f(graph_x, edges, phenotypes, None, None)
            b, n, d = x.shape

            self.edges = edges
            x = torch.nn.functional.relu(g(torch.dropout(x.view(-1, d), self.dropout, train=self.training), edges)).view(b, n, -1)

            if lprobs is not None:
                lprobslist.append(lprobs)
                att_weights_list.append(att_weights)

        if self.fc_layers is not None and len(self.fc_layers) > 0:
            return self.fc(x), torch.stack(att_weights_list, -1) if len(att_weights_list) > 0 else None, torch.stack(lprobslist, -1) if len(lprobslist) > 0 else None
        return x, torch.stack(att_weights_list, -1) if len(att_weights_list) > 0 else None, torch.stack(lprobslist, -1) if len(lprobslist) > 0 else None

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer
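
    # --- Editor's note (not part of the original file): optimization scheme. ---
    # automatic_optimization is disabled, so training_step drives AdamW manually:
    # it backpropagates the prediction loss first, then the REINFORCE-style graph
    # loss through the sampler's log-probabilities, and only then calls optimizer.step().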

    def training_step(self, train_batch, batch_idx):
        optimizer = self.optimizers(use_pl_optimizer=True)
        optimizer.zero_grad()

        X, y, mask, phenotypes, edges = train_batch
        edges = edges[0]

        assert (X.shape[0] == 1)  # only works in the transductive setting
        mask = mask[0]

        pred, att_weights, logprobs = self(X, edges, phenotypes)
        if self.task == 'classification':
            train_pred = pred[:, mask.to(torch.bool), :]
            train_lab = y[:, mask.to(torch.bool), :]
        elif self.task == 'regression':
            pred = pred.squeeze_(-1)
            train_pred = pred[:, mask.to(torch.bool)]
            train_lab = y[:, mask.to(torch.bool)]
        else:
            raise ValueError('Task should be classification or regression.')

        if self.criterion == 'torch.nn.functional.cross_entropy()':
            loss = torch.nn.functional.cross_entropy(train_pred.view(-1, train_pred.shape[-1]), train_lab.argmax(-1).flatten())
        elif self.criterion == 'torch.nn.functional.binary_cross_entropy_with_logits()':
            loss = torch.nn.functional.binary_cross_entropy_with_logits(train_pred, train_lab)
        elif self.criterion == 'torch.nn.HuberLoss()':
            loss = torch.nn.HuberLoss()(train_pred.squeeze_(), train_lab.squeeze_())
        else:
            raise ValueError('Choose appropriate loss function.')

        self.manual_backward(loss)

        # Estimate graph loss
        if self.task == 'classification':
            correct_t = (train_pred.argmax(-1) == train_lab.argmax(-1)).float().mean().item()
            # GRAPH LOSS
            if logprobs is not None:
                corr_pred = (train_pred.argmax(-1) == train_lab.argmax(-1)).float().detach()  # 0 or 1
                if self.avg_accuracy is None:
                    self.avg_accuracy = torch.ones_like(corr_pred)*0.5
                point_w = (self.avg_accuracy - corr_pred)
                graph_loss = point_w * logprobs[:, mask.to(torch.bool), :].exp().mean([-1, -2])
                graph_loss = graph_loss.mean()
                graph_loss.backward()
                self.log('train_point_w', point_w.mean().detach().cpu())
                self.log('train_graph_loss', graph_loss.detach().cpu())
                self.avg_accuracy = self.avg_accuracy.to(corr_pred.device)*0.95 + 0.05*corr_pred
            optimizer.step()

            self.log('train_acc', 100*correct_t)
            self.log('train_loss', loss.detach().cpu())
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    if weight.shape[1] > 1:
                        weight = torch.mean(weight)
                    name = self.phenotype_columns[i]
                    self.log(f'train_{name}', weight.detach().cpu())
        elif self.task == 'regression':
            abs_error = abs(train_pred.squeeze_() - train_lab.squeeze_()).mean().item()
            # GRAPH LOSS
            if logprobs is not None:
                mae = abs(train_pred.squeeze_() - train_lab.squeeze_()).detach()
                if self.avg_mae is None:
                    self.avg_mae = torch.ones_like(mae)*self.graph_loss_mae
                point_w = (mae - self.avg_mae)
                graph_loss = (point_w * logprobs[:, mask.to(torch.bool)].exp().mean([-1, -2]))
                graph_loss = graph_loss.mean()
                graph_loss.backward()
                self.log('train_point_w', point_w.mean().detach().cpu())
                self.log('train_graph_loss', graph_loss.detach().cpu())
                self.avg_mae = self.avg_mae.to(mae.device)*0.95 + 0.05*mae
            optimizer.step()
            self.log('train_abs_error', abs_error)
            self.log('train_loss', loss.detach().cpu())
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    if weight.shape[1] > 1:
                        weight = torch.mean(weight)
                    name = self.phenotype_columns[i]
                    self.log(f'train_{name}', weight.detach().cpu())
        else:
            raise ValueError('Task should be either regression or classification.')

    def validation_step(self, val_batch, batch_idx):
        X, y, mask, phenotypes, edges = val_batch
        edges = edges[0]
        assert (X.shape[0] == 1)  # only works in the transductive setting
        mask = mask[0]

        pred, att_weights, logprobs = self(X, edges, phenotypes)
        if self.task == 'classification':
            pred = pred.softmax(-1)
            for i in range(1, self.test_eval):
                pred_, att_weights, logprobs = self(X, edges, phenotypes)
                pred += pred_.softmax(-1)
            test_pred = pred[:, mask.to(torch.bool), :]
            test_lab = y[:, mask.to(torch.bool), :]
        elif self.task == 'regression':
            pred = pred.squeeze_(-1)
            for i in range(1, self.test_eval):
                pred_, att_weights, logprobs = self(X, edges, phenotypes)
                pred += pred_.squeeze_(-1)
            pred = pred / self.test_eval
            test_pred = pred[:, mask.to(torch.bool)]
            test_lab = y[:, mask.to(torch.bool)]
        else:
            raise ValueError('Task should be classification or regression.')

        if self.criterion == 'torch.nn.functional.cross_entropy()':
            loss = torch.nn.functional.cross_entropy(test_pred.view(-1, test_pred.shape[-1]), test_lab.argmax(-1).flatten())
        elif self.criterion == 'torch.nn.functional.binary_cross_entropy_with_logits()':
            loss = torch.nn.functional.binary_cross_entropy_with_logits(test_pred, test_lab)
        elif self.criterion == 'torch.nn.HuberLoss()':
            loss = torch.nn.HuberLoss()(test_pred.squeeze_(), test_lab.squeeze_())
        else:
            raise ValueError('Choose appropriate loss function.')

        if self.task == 'classification':
            correct_t = (test_pred.argmax(-1) == test_lab.argmax(-1)).float().mean().item()

            accuracy = self.accuracy(test_pred.argmax(-1), test_lab.argmax(-1))
            auc = self.auc(test_pred.squeeze_().softmax(-1), test_lab.squeeze_().argmax(-1))
            f1score = self.f1score(test_pred.squeeze_().argmax(-1), test_lab.squeeze_().argmax(-1))

            self.log('val_accuracy', accuracy)
            self.log('val_AUC', auc)
            self.log('val_f1score', f1score)

            self.log('val_loss', loss.detach())
            self.log('val_acc', 100*correct_t)
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    name = self.phenotype_columns[i]
                    self.log(f'val_{name}', weight.detach().cpu())
        elif self.task == 'regression':
            abs_error = abs(test_pred.squeeze_() - test_lab.squeeze_()).mean().item()

            mean_absolute_error = self.mean_absolute_error(test_pred.squeeze_(), test_lab.squeeze_())
            rscore = self.rscore(test_pred.squeeze_(), test_lab.squeeze_())

            self.log('val_mean_absolute_error', mean_absolute_error)
            self.log('val_rscore', rscore)

            self.log('val_loss', loss)
            self.log('val_abs_error', abs_error)
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    name = self.phenotype_columns[i]
                    self.log(f'val_{name}', weight.detach().cpu())
        else:
            raise ValueError('Task should be either regression or classification.')

    def test_step(self, test_batch, batch_idx):
        X, y, mask, phenotypes, edges = test_batch
        edges = edges[0]
        assert (X.shape[0] == 1)  # only works in the transductive setting
        mask = mask[0]

        pred, att_weights, logprobs = self(X, edges, phenotypes)
        if self.task == 'classification':
            pred = pred.softmax(-1)
            for i in range(1, self.test_eval):
                pred_, att_weights, logprobs = self(X, edges, phenotypes)
                pred += pred_.softmax(-1)
            test_pred = pred[:, mask.to(torch.bool), :]
            test_lab = y[:, mask.to(torch.bool), :]
        elif self.task == 'regression':
            pred = pred.squeeze_(-1)
            for i in range(1, self.test_eval):
                pred_, att_weights, logprobs = self(X, edges, phenotypes)
                pred += pred_.squeeze_(-1)
            pred = pred / self.test_eval
            test_pred = pred[:, mask.to(torch.bool)]
            test_lab = y[:, mask.to(torch.bool)]
        else:
            raise ValueError('Task should be classification or regression.')

        if self.criterion == 'torch.nn.functional.cross_entropy()':
            loss = torch.nn.functional.cross_entropy(test_pred.view(-1, test_pred.shape[-1]), test_lab.argmax(-1).flatten())
        elif self.criterion == 'torch.nn.functional.binary_cross_entropy_with_logits()':
            loss = torch.nn.functional.binary_cross_entropy_with_logits(test_pred, test_lab)
        elif self.criterion == 'torch.nn.HuberLoss()':
            loss = torch.nn.HuberLoss()(test_pred.squeeze_(), test_lab.squeeze_())
        else:
            raise ValueError('Choose appropriate loss function.')

        if self.task == 'classification':
            correct_t = (test_pred.argmax(-1) == test_lab.argmax(-1)).float().mean().item()

            accuracy = self.accuracy(test_pred.argmax(-1), test_lab.argmax(-1))
            auc = self.auc(test_pred.squeeze_().softmax(-1), test_lab.squeeze_().argmax(-1))
            f1score = self.f1score(test_pred.squeeze_().argmax(-1), test_lab.squeeze_().argmax(-1))

            self.log('test_accuracy', accuracy)
            self.log('test_AUC', auc)
            self.log('test_f1score', f1score)

            self.log('test_loss', loss.detach().cpu())
            self.log('test_acc', 100*correct_t)
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    name = self.phenotype_columns[i]
                    self.log(f'test_{name}', weight.detach().cpu())
        elif self.task == 'regression':
            abs_error = abs(test_pred.squeeze_() - test_lab.squeeze_()).mean().item()

            mean_absolute_error = self.mean_absolute_error(test_pred.squeeze_(), test_lab.squeeze_())
            rscore = self.rscore(test_pred.squeeze_(), test_lab.squeeze_())

            self.log('test_mean_absolute_error', mean_absolute_error)
            self.log('test_rscore', rscore)

            self.log('test_loss', loss)
            self.log('test_abs_error', abs_error)
            if att_weights is not None:
                for i, weight in enumerate(torch.transpose(att_weights, 0, 1)):
                    name = self.phenotype_columns[i]
                    self.log(f'test_{name}', weight.detach().cpu())
        else:
            raise ValueError('Task should be either regression or classification.')
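
# --- Editor's sketch (not part of the original file): the graph loss in isolation. ---
# The edge sampler is not differentiable, so it is trained REINFORCE-style: edges that
# took part in a better-than-running-average prediction get their log-probabilities
# pushed up, and vice versa. The numbers below are purely illustrative.
if __name__ == '__main__':
    mae = torch.tensor([4.0, 8.0])      # per-subject absolute error of this step
    avg_mae = torch.tensor([6.0, 6.0])  # running average (initialised to graph_loss_mae)
    point_w = mae - avg_mae             # negative -> reinforce the sampled edges
    logprobs = torch.randn(2, 5).log_softmax(-1)  # stand-in for the sampler's logprobs
    graph_loss = (point_w * logprobs.exp().mean(-1)).mean()
    print(graph_loss)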

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from torch_geometric.loader import DataLoader

from argparse import Namespace

from graph_constuction import PopulationGraphUKBB, UKBBageDataset
from model import GraphLearningModel

run_params = {
    "gpus": 1,
    "log_every_n_steps": 100,
    "max_epochs": 150,
    "progress_bar_refresh_rate": 1,
    "check_val_every_n_epoch": 1,

    "conv_layers": [[68, 512]],
    "dgm_layers": [[35, 35], []],
    "fc_layers": [512, 128, 1],
    "pre_fc": None,

    "gfun": 'gcn',
    "ffun": 'phenotypes',
    "k": 5,
    "pooling": 'add',
    "distance": 'euclidean',

    "dropout": 0,
    "lr": 0.001,
    "test_eval": 10,

    "num_node_features": 68,
    "num_classes": 1,
    "task": 'regression',

    "graph_loss_mae": 6,
    "edges": 'phenotypes',

    "phenotype_columns": ['Sex', 'Height', 'Body mass index (BMI)', 'Systolic blood pressure', 'Diastolic blood pressure', 'College education', 'Smoking status',
                          'Alcohol intake frequency', 'Stroke', 'Diabetes', 'Walking per week', 'Vigorous per week', 'Fluid intelligence', 'Tower rearranging: number of puzzles correct',
                          'Trail making task: duration to complete numeric path trail 1', 'Trail making task: duration to complete alphanumeric path trail 2', 'Matrix pattern completion: number of puzzles correctly solved',
                          'Matrix pattern completion: duration spent answering each puzzle',
                          'Volume of grey matter (normalised for head size)', 'Volume of brain stem + 4th ventricle', 'Volume of grey matter in Putamen (right)', 'Volume of thalamus (right)',
                          'Volume of putamen (left)', 'Volume of grey matter in Thalamus (right)', 'Total volume of white matter hyperintensities (from T1 and T2_FLAIR images)', 'Weighted-mean ICVF in tract forceps minor',
                          'Weighted-mean L1 in tract anterior thalamic radiation (right)', 'Mean FA in cerebral peduncle on FA skeleton (left)',
                          'Mean FA in superior cerebellar peduncle on FA skeleton (left)', 'Mean L1 in middle cerebellar peduncle on FA skeleton', 'Mean ICVF in body of corpus callosum on FA skeleton',
                          'Weighted-mean L3 in tract uncinate fasciculus (left)', 'Mean ISOVF in fornix on FA skeleton',
                          'Weighted-mean L1 in tract parahippocampal part of cingulum (left)', 'Mean L2 in fornix cres+stria terminalis on FA skeleton (left)'],
}
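
# --- Editor's note (not part of the original file): how the sizes above fit together. ---
# dgm_layers [[35, 35], []] -> the attention MLP scores the 35 phenotype_columns listed
#                              above; its input size is overwritten from the data in main().
# conv_layers [[68, 512]]   -> the graph convolution maps the node features to 512 hidden
#                              channels; its input size is also overwritten in main().
# fc_layers [512, 128, 1]   -> the final MLP head; its last entry is overwritten with
#                              num_classes (a single output for regression).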
62 | """ 63 | 64 | # We have selected imaging features + non-imaging features 65 | # that are found relevant to brain-age from J.Cole's paper 66 | # https://pubmed.ncbi.nlm.nih.gov/32380363/ 67 | data_dir = 'data/' 68 | filename_train = 'train.csv' 69 | filename_val = 'val.csv' 70 | filename_test = 'test.csv' 71 | 72 | # Keep only the imaging features as node features 73 | node_columns = [0, 1, 22, 90] 74 | num_node_features = node_columns[3] - node_columns[2] 75 | 76 | task = run_params.task 77 | num_classes = run_params.num_classes 78 | k = run_params.k 79 | edges = run_params.edges 80 | 81 | #Phenotypes chosen for the extraction of the edges 82 | phenotype_columns = run_params.phenotype_columns 83 | 84 | population_graph = PopulationGraphUKBB(data_dir, filename_train, filename_val, filename_test, phenotype_columns, node_columns, 85 | num_node_features, task, num_classes, k, edges) 86 | population_graph = population_graph.get_population_graph() 87 | return population_graph 88 | 89 | def main(run_params, graph): 90 | train_data = None 91 | test_data = None 92 | 93 | # Load data 94 | train_data = UKBBageDataset(graph=graph, split='train', device='cuda', num_classes = run_params.num_classes) 95 | val_data = UKBBageDataset(graph=graph, split='val', samples_per_epoch=1, num_classes = run_params.num_classes) 96 | test_data = UKBBageDataset(graph=graph, split='test', samples_per_epoch=1, num_classes = run_params.num_classes) 97 | 98 | train_loader = DataLoader(train_data, batch_size=1,num_workers=0) 99 | val_loader = DataLoader(val_data, batch_size=1) 100 | test_loader = DataLoader(test_data, batch_size=1) 101 | 102 | class MyDataModule(pl.LightningDataModule): 103 | def setup(self,stage=None): 104 | pass 105 | def train_dataloader(self): 106 | return train_loader 107 | def val_dataloader(self): 108 | return val_loader 109 | def test_dataloader(self): 110 | return test_loader 111 | 112 | if train_data is None: 113 | raise Exception("Dataset %s not supported" % run_params.dataset) 114 | 115 | #configure input feature sizes 116 | if run_params.pre_fc is None or len(run_params.pre_fc)==0: 117 | if len(run_params.dgm_layers[0])>0: 118 | run_params.dgm_layers[0][0]=train_data.phenotypes.shape[1] 119 | run_params.conv_layers[0][0]=train_data.n_features 120 | else: 121 | run_params.pre_fc[0]= train_data.n_features 122 | 123 | if run_params.fc_layers is not None: 124 | run_params.fc_layers[-1] = train_data.num_classes 125 | 126 | model = GraphLearningModel(run_params) 127 | print(model) 128 | 129 | if run_params.task == 'regression': 130 | checkpoint_callback = ModelCheckpoint( 131 | save_last=False, 132 | save_top_k=1, 133 | verbose=False, 134 | monitor='val_loss', 135 | mode='min' 136 | ) 137 | elif run_params.task == 'classification': 138 | checkpoint_callback = ModelCheckpoint( 139 | save_last=False, 140 | save_top_k=1, 141 | verbose=False, 142 | monitor='val_acc', 143 | mode='max') 144 | else: 145 | raise ValueError('Task should be either regression or classification.') 146 | 147 | callbacks = [checkpoint_callback] 148 | if val_data==test_data: 149 | callbacks = None 150 | 151 | logger = TensorBoardLogger("logs/regression/") 152 | trainer = pl.Trainer.from_argparse_args(run_params,logger=logger, 153 | callbacks=callbacks) 154 | 155 | trainer.fit(model, datamodule=MyDataModule()) 156 | 157 | # Evaluate results on validation and test set 158 | val_results = trainer.validate(ckpt_path=checkpoint_callback.best_model_path, dataloaders=val_loader) 159 | test_results = 

if type(run_params) is not Namespace:
    run_params = Namespace(**run_params)

population_graph = costruct_graph(run_params)
val_results, test_results, best_model_checkpoint_path = main(run_params, population_graph)
--------------------------------------------------------------------------------