# Repository dump of CodeGoat24/PLSNet @ 734442dd57a7374072c8a891df8c0a05776f4eda
#
# Layout:
# ├── Biomarkers.png
# ├── PLSNet.png
# ├── dataloader.py
# ├── imports
# │   ├── ABIDEDataset.py
# │   ├── __inits__.py
# │   ├── gdc.py
# │   ├── preprocess_data.py
# │   ├── read_abide_stats_parall.py
# │   └── utils.py
# ├── main.py
# ├── model
# │   ├── Encoder.py
# │   ├── __init__.py
# │   ├── __pycache__
# │   │   ├── Encoder.cpython-37.pyc
# │   │   ├── GAU.cpython-37.pyc
# │   │   ├── __init__.cpython-37.pyc
# │   │   └── model.cpython-37.pyc
# │   └── model.py
# ├── readme.md
# ├── setting
# │   └── abide_PLSNet.yaml
# ├── train.py
# └── util
#     ├── __init__.py
#     ├── __pycache__
#     │   ├── __init__.cpython-37.pyc
#     │   ├── logger.cpython-37.pyc
#     │   ├── loss.cpython-37.pyc
#     │   ├── meter.cpython-37.pyc
#     │   └── prepossess.cpython-37.pyc
#     ├── abide
#     │   ├── 01-fetch_data.py
#     │   ├── 02-process_data.py
#     │   ├── 03-generate_abide_dataset.py
#     │   ├── __pycache__
#     │   │   └── preprocess_data.cpython-37.pyc
#     │   ├── preprocess_data.py
#     │   ├── readme.md
#     │   └── subject_IDs.txt
#     ├── analysis
#     │   └── extract_info_from_log.py
#     ├── logger.py
#     ├── loss.py
#     ├── meter.py
#     └── prepossess.py
#
# ---------------------------------------------------------------------------
# /Biomarkers.png
# https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/Biomarkers.png
# ---------------------------------------------------------------------------
# /PLSNet.png
# https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/PLSNet.png
# ---------------------------------------------------------------------------
# /dataloader.py
# ---------------------------------------------------------------------------

import numpy as np
import torch
import torch.utils.data as utils
import csv

from nilearn.connectome import ConnectivityMeasure
from sklearn import preprocessing
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from nilearn import plotting, datasets


class StandardScaler:
    """Standardise data with a fixed, precomputed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Return ``(data - mean) / std``."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`, i.e. return ``data * std + mean``."""
        return (data * self.std) + self.mean


def init_dataloader(dataset_config):
    """Build train/val/test DataLoaders from a pickled ``.npy`` dictionary.

    ``dataset_config`` keys read here:
        ``"time_seires"``: path given to ``np.load(..., allow_pickle=True)``;
            the stored object must be a dict with keys ``"timeseires"``,
            ``"corr"`` and ``"label"`` (both data arrays are 3-D).
        ``"train_set"`` / ``"val_set"``: fractions of the subjects used for
            the train and validation splits; the remainder is the test split.
        ``"batch_size"``: batch size used by all three loaders.

    Returns:
        ``((train_loader, val_loader, test_loader), node_size,
        node_feature_size, timeseries)``, where ``node_size`` and
        ``node_feature_size`` are the last two dimensions of ``"corr"`` and
        ``timeseries`` is the last dimension of ``"timeseires"``.
        Each batch yields ``(time_series, corr, label, pseudo)`` tensors.
    """
    data = np.load(dataset_config["time_seires"], allow_pickle=True).item()
    final_fc = data["timeseires"]
    final_pearson = data["corr"]
    labels = data["label"]

    _, _, timeseries = final_fc.shape
    _, node_size, node_feature_size = final_pearson.shape

    # Standardise the time series with a single global mean / std.
    scaler = StandardScaler(mean=np.mean(final_fc), std=np.std(final_fc))
    final_fc = scaler.transform(final_fc)

    # One (node_size x node_size) identity matrix per subject. Sizing it from
    # the data instead of hard-coded per-atlas node counts gives the same
    # result for cc200/aal/cc400 and also works for any other atlas.
    pseudo_arr = np.tile(np.eye(node_size), (len(final_fc), 1, 1))

    final_fc, final_pearson, labels, pseudo_arr = [
        torch.from_numpy(array).float()
        for array in (final_fc, final_pearson, labels, pseudo_arr)]

    length = final_fc.shape[0]
    train_length = int(length * dataset_config["train_set"])
    val_length = int(length * dataset_config["val_set"])

    dataset = utils.TensorDataset(final_fc, final_pearson, labels, pseudo_arr)

    # NOTE(review): random_split is not seeded here, so the split depends on
    # the global torch RNG state when this function is called.
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_length, val_length, length - train_length - val_length])

    batch_size = dataset_config["batch_size"]
    train_dataloader = utils.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    val_dataloader = utils.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    test_dataloader = utils.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    return (train_dataloader, val_dataloader, test_dataloader), node_size, node_feature_size, timeseries


# ---------------------------------------------------------------------------
# /imports/ABIDEDataset.py
# ---------------------------------------------------------------------------
import torch
from torch_geometric.data import InMemoryDataset, Data
from os.path import join, isfile
from os import listdir
import numpy as np
import os.path as osp
from imports.read_abide_stats_parall import read_data


class ABIDEDataset(InMemoryDataset):
    """In-memory PyG dataset whose raw inputs are the files in ``<root>/raw``.

    Processing is delegated to ``read_data``; the collated result is cached
    in ``<root>/processed/data.pt`` and loaded on construction.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.root = root
        self.name = name
        super(ABIDEDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Sorted names of the regular files in ``<root>/raw``."""
        data_dir = osp.join(self.root, 'raw')
        onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))]
        onlyfiles.sort()
        return onlyfiles

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Raw files must already be present in `self.raw_dir`; nothing to fetch.
        return

    def process(self):
        """Read the raw files, apply optional filter/transform, and cache."""
        # Read data into one collated (data, slices) pair.
        self.data, self.slices = read_data(self.raw_dir)

        if self.pre_filter is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [data for data in data_list if self.pre_filter(data)]
            self.data, self.slices = self.collate(data_list)

        if self.pre_transform is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [self.pre_transform(data) for data in data_list]
            self.data, self.slices = self.collate(data_list)

        torch.save((self.data, self.slices), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))


# ---------------------------------------------------------------------------
# /imports/__inits__.py
# https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/imports/__inits__.py
# ---------------------------------------------------------------------------
# /imports/gdc.py
# ---------------------------------------------------------------------------
import torch
import numba
import numpy as np
from scipy.linalg import expm
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
from torch_sparse import coalesce
from torch_scatter import scatter_add


def jit():
    """Return a ``numba.jit`` decorator that prefers on-disk caching.

    If ``numba.jit(cache=True)`` raises ``RuntimeError`` (presumably when no
    usable cache location exists), the function is compiled with
    ``cache=False`` instead.
    """
    def decorator(func):
        try:
            return numba.jit(cache=True)(func)
        except RuntimeError:
            return numba.jit(cache=False)(func)

    return decorator


class GDC(object):
    r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
    `"Diffusion Improves Graph Learning" <https://arxiv.org/abs/1911.05485>`_
    paper.

    .. note::
        The paper offers additional advice on how to choose the
        hyperparameters.
        For an example of using GCN with GDC, see `examples/gcn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py>`_.

    Args:
        self_loop_weight (float, optional): Weight of the added self-loop.
            Set to :obj:`None` to add no self-loops. (default: :obj:`1`)
        normalization_in (str, optional): Normalization of the transition
            matrix on the original (input) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"sym"`)
        normalization_out (str, optional): Normalization of the transition
            matrix on the transformed GDC (output) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"col"`)
        diffusion_kwargs (dict, optional): Dictionary containing the parameters
            for diffusion.
            `method` specifies the diffusion method (:obj:`"ppr"`,
            :obj:`"heat"` or :obj:`"coeff"`).
            Each diffusion method requires different additional parameters.
            See :func:`GDC.diffusion_matrix_exact` or
            :func:`GDC.diffusion_matrix_approx` for details.
            :obj:`None` selects the default.
            (default: :obj:`dict(method='ppr', alpha=0.15)`)
        sparsification_kwargs (dict, optional): Dictionary containing the
            parameters for sparsification.
            `method` specifies the sparsification method (:obj:`"threshold"` or
            :obj:`"topk"`).
            Each sparsification method requires different additional
            parameters.
            See :func:`GDC.sparsify_dense` for details.
            :obj:`None` selects the default.
            (default: :obj:`dict(method='threshold', avg_degree=64)`)
        exact (bool, optional): Whether to exactly calculate the diffusion
            matrix.
            Note that the exact variants are not scalable.
            They densify the adjacency matrix and calculate either its inverse
            or its matrix exponential.
            However, the approximate variants do not support edge weights and
            currently only personalized PageRank and sparsification by
            threshold are implemented as fast, approximate versions.
            (default: :obj:`True`)

    :rtype: :class:`torch_geometric.data.Data`
    """

    def __init__(self, self_loop_weight=1, normalization_in='sym',
                 normalization_out='col', diffusion_kwargs=None,
                 sparsification_kwargs=None, exact=True):
        # Fresh dicts per instance instead of shared mutable defaults.
        if diffusion_kwargs is None:
            diffusion_kwargs = dict(method='ppr', alpha=0.15)
        if sparsification_kwargs is None:
            sparsification_kwargs = dict(method='threshold', avg_degree=64)

        self.self_loop_weight = self_loop_weight
        self.normalization_in = normalization_in
        self.normalization_out = normalization_out
        self.diffusion_kwargs = diffusion_kwargs
        self.sparsification_kwargs = sparsification_kwargs
        self.exact = exact

        if self_loop_weight:
            # The approximate path only supports unit self-loop weights.
            assert exact or self_loop_weight == 1

    @torch.no_grad()
    def __call__(self, data):
        """Replace ``data.edge_index`` / ``data.edge_attr`` by the GDC graph."""
        N = data.num_nodes
        edge_index = data.edge_index
        if data.edge_attr is None:
            edge_weight = torch.ones(edge_index.size(1),
                                     device=edge_index.device)
        else:
            edge_weight = data.edge_attr
            # Weighted input graphs are only supported by the exact variant.
            assert self.exact
            assert edge_weight.dim() == 1

        if self.self_loop_weight:
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight, fill_value=self.self_loop_weight,
                num_nodes=N)

        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)

        if self.exact:
            edge_index, edge_weight = self.transition_matrix(
                edge_index, edge_weight, N, self.normalization_in)
            diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N,
                                                   **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_dense(
                diff_mat, **self.sparsification_kwargs)
        else:
            edge_index, edge_weight = self.diffusion_matrix_approx(
                edge_index, edge_weight, N, self.normalization_in,
                **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_sparse(
                edge_index, edge_weight, N, **self.sparsification_kwargs)

        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        edge_index, edge_weight = self.transition_matrix(
            edge_index, edge_weight, N, self.normalization_out)

        data.edge_index = edge_index
        data.edge_attr = edge_weight

        return data

    def transition_matrix(self, edge_index, edge_weight, num_nodes,
                          normalization):
        r"""Calculate the (normalized) transition matrix of a sparse graph.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Normalization scheme:

                1. :obj:`"sym"`: Symmetric normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A}
                   \mathbf{D}^{-1/2}`.
                2. :obj:`"col"`: Column-wise normalization
                   :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`.
                3. :obj:`"row"`: Row-wise normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`.
                4. :obj:`None`: No normalization.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if normalization == 'sym':
            row, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv_sqrt = deg.pow(-0.5)
            # Isolated nodes (degree 0) would give inf; zero them instead.
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        elif normalization == 'col':
            _, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[col]
        elif normalization == 'row':
            row, _ = edge_index
            deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[row]
        elif normalization is None:
            pass
        else:
            raise ValueError(
                'Transition matrix normalization {} unknown.'.format(
                    normalization))

        return edge_index, edge_weight

    def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes,
                               method, **kwargs):
        r"""Calculate the (dense) diffusion on a given sparse graph.

        Note that these exact variants are not scalable. They densify the
        adjacency matrix and calculate either its inverse or its matrix
        exponential.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Diffusion method:

                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameter:

                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.

                2. :obj:`"heat"`: Use heat kernel diffusion.
                   Additionally expects the parameter:

                   - **t** (*float*) - Time of diffusion. Commonly lies in
                     :obj:`[2, 10]`.

                3. :obj:`"coeff"`: Freely choose diffusion coefficients.
                   Additionally expects the parameter:

                   - **coeffs** (*List[float]*) - List of coefficients
                     :obj:`theta_k` for each power of the transition matrix
                     (starting at :obj:`0`).

        :rtype: (:class:`Tensor`)
        """
        if method == 'ppr':
            # α (I_n + (α - 1) A)^-1
            edge_weight = (kwargs['alpha'] - 1) * edge_weight
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=1,
                                                     num_nodes=num_nodes)
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            diff_matrix = kwargs['alpha'] * torch.inverse(mat)

        elif method == 'heat':
            # exp(t (A - I_n))
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=-1,
                                                     num_nodes=num_nodes)
            edge_weight = kwargs['t'] * edge_weight
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            undirected = is_undirected(edge_index, edge_weight, num_nodes)
            diff_matrix = self.__expm__(mat, undirected)

        elif method == 'coeff':
            adj_matrix = to_dense_adj(edge_index,
                                      edge_attr=edge_weight).squeeze()
            mat = torch.eye(num_nodes, device=edge_index.device)

            # sum_k theta_k * A^k
            diff_matrix = kwargs['coeffs'][0] * mat
            for coeff in kwargs['coeffs'][1:]:
                mat = mat @ adj_matrix
                diff_matrix += coeff * mat
        else:
            raise ValueError('Exact GDC diffusion {} unknown.'.format(method))

        return diff_matrix

    def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes,
                                normalization, method, **kwargs):
        r"""Calculate the approximate, sparse diffusion on a given sparse
        graph.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Transition matrix normalization scheme
                (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`).
                See :func:`GDC.transition_matrix` for details.
            method (str): Diffusion method:

                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameters:

                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.
                   - **eps** (*float*) - Threshold for PPR calculation stopping
                     criterion (:obj:`edge_weight >= eps * out_degree`).
                     Recommended default: :obj:`1e-4`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'ppr':
            if normalization == 'sym':
                # Calculate original degrees.
                _, col = edge_index
                deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)

            edge_index_np = edge_index.cpu().numpy()
            # Assumes coalesced edge_index.
            _, indptr, out_degree = np.unique(edge_index_np[0],
                                              return_index=True,
                                              return_counts=True)

            neighbors, neighbor_weights = GDC.__calc_ppr__(
                indptr, edge_index_np[1], out_degree, kwargs['alpha'],
                kwargs['eps'])
            ppr_normalization = 'col' if normalization == 'col' else 'row'
            edge_index, edge_weight = self.__neighbors_to_graph__(
                neighbors, neighbor_weights, ppr_normalization,
                device=edge_index.device)
            edge_index = edge_index.to(torch.long)

            if normalization == 'sym':
                # We can change the normalization from row-normalized to
                # symmetric by multiplying the resulting matrix with D^{1/2}
                # from the left and D^{-1/2} from the right.
                # Since we use the original degrees for this it will be like
                # we had used symmetric normalization from the beginning
                # (except for errors due to approximation).
                row, col = edge_index
                deg_inv = deg.sqrt()
                deg_inv_sqrt = deg.pow(-0.5)
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
                edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col]
            elif normalization in ['col', 'row']:
                pass
            else:
                raise ValueError(
                    ('Transition matrix normalization {} not implemented for '
                     'non-exact GDC computation.').format(normalization))

        elif method == 'heat':
            raise NotImplementedError(
                ('Currently no fast heat kernel is implemented. You are '
                 'welcome to create one yourself, e.g., based on '
                 '"Kloster and Gleich: Heat kernel based community detection '
                 '(KDD 2014)."'))
        else:
            raise ValueError(
                'Approximate GDC diffusion {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_dense(self, matrix, method, **kwargs):
        r"""Sparsifies the given dense matrix.

        Args:
            matrix (Tensor): Square matrix to sparsify.
            method (str): Method of sparsification. Options:

                1. :obj:`"threshold"`: Remove all edges with weights smaller
                   than :obj:`eps`.
                   Additionally expects one of these parameters:

                   - **eps** (*float*) - Threshold to bound edges at.
                   - **avg_degree** (*int*) - If :obj:`eps` is not given,
                     it can optionally be calculated by calculating the
                     :obj:`eps` required to achieve a given :obj:`avg_degree`.

                2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per
                   node (column).
                   Additionally expects the following parameters:

                   - **k** (*int*) - Specifies the number of edges to keep.
                   - **dim** (*int*) - The axis along which to take the top
                     :obj:`k`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        assert matrix.shape[0] == matrix.shape[1]
        N = matrix.shape[1]

        if method == 'threshold':
            if 'eps' not in kwargs.keys():
                kwargs['eps'] = self.__calculate_eps__(matrix, N,
                                                       kwargs['avg_degree'])

            edge_index = torch.nonzero(matrix >= kwargs['eps']).t()
            edge_index_flat = edge_index[0] * N + edge_index[1]
            edge_weight = matrix.flatten()[edge_index_flat]

        elif method == 'topk':
            assert kwargs['dim'] in [0, 1]
            sort_idx = torch.argsort(matrix, dim=kwargs['dim'],
                                     descending=True)
            if kwargs['dim'] == 0:
                top_idx = sort_idx[:kwargs['k']]
                edge_weight = torch.gather(matrix, dim=kwargs['dim'],
                                           index=top_idx).flatten()

                row_idx = torch.arange(0, N, device=matrix.device).repeat(
                    kwargs['k'])
                edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0)
            else:
                top_idx = sort_idx[:, :kwargs['k']]
                edge_weight = torch.gather(matrix, dim=kwargs['dim'],
                                           index=top_idx).flatten()

                col_idx = torch.arange(
                    0, N, device=matrix.device).repeat_interleave(kwargs['k'])
                edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0)
        else:
            raise ValueError('GDC sparsification {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method,
                        **kwargs):
        r"""Sparsifies a given sparse graph further.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Method of sparsification:

                1. :obj:`"threshold"`: Remove all edges with weights smaller
                   than :obj:`eps`.
                   Additionally expects one of these parameters:

                   - **eps** (*float*) - Threshold to bound edges at.
                   - **avg_degree** (*int*) - If :obj:`eps` is not given,
                     it can optionally be calculated by calculating the
                     :obj:`eps` required to achieve a given :obj:`avg_degree`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'threshold':
            if 'eps' not in kwargs.keys():
                kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes,
                                                       kwargs['avg_degree'])

            remaining_edge_idx = torch.nonzero(
                edge_weight >= kwargs['eps']).flatten()
            edge_index = edge_index[:, remaining_edge_idx]
            edge_weight = edge_weight[remaining_edge_idx]
        elif method == 'topk':
            raise NotImplementedError(
                'Sparse topk sparsification not implemented.')
        else:
            raise ValueError('GDC sparsification {} unknown.'.format(method))

        return edge_index, edge_weight

    def __expm__(self, matrix, symmetric):
        r"""Calculates matrix exponential.

        Args:
            matrix (Tensor): Matrix to take exponential of.
            symmetric (bool): Specifies whether the matrix is symmetric.

        :rtype: (:class:`Tensor`)
        """
        if symmetric:
            # exp(M) = V diag(exp(e)) V^T for symmetric M. torch.symeig was
            # removed in PyTorch 2.x; use torch.linalg.eigh when available.
            if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
                e, V = torch.linalg.eigh(matrix)
            else:
                e, V = torch.symeig(matrix, eigenvectors=True)
            diff_mat = V @ torch.diag(e.exp()) @ V.t()
        else:
            diff_mat_np = expm(matrix.cpu().numpy())
            diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
        return diff_mat

    def __calculate_eps__(self, matrix, num_nodes, avg_degree):
        r"""Calculates threshold necessary to achieve a given average degree.

        Args:
            matrix (Tensor): Adjacency matrix or edge weights.
            num_nodes (int): Number of nodes.
            avg_degree (int): Target average degree.

        :rtype: (:class:`float`)
        """
        sorted_edges = torch.sort(matrix.flatten(), descending=True).values
        if avg_degree * num_nodes > len(sorted_edges):
            # Fewer entries than requested: keep everything.
            return -np.inf
        return sorted_edges[avg_degree * num_nodes - 1]

    def __neighbors_to_graph__(self, neighbors, neighbor_weights,
                               normalization='row', device='cpu'):
        r"""Combine a list of neighbors and neighbor weights to create a sparse
        graph.

        Args:
            neighbors (List[List[int]]): List of neighbors for each node.
            neighbor_weights (List[List[float]]): List of weights for the
                neighbors of each node.
            normalization (str): Normalization of resulting matrix
                (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`)
            device (torch.device): Device to create output tensors on.
                (default: :obj:`"cpu"`)

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device)
        # np.int was removed in NumPy 1.24; use an explicit integer dtype.
        i = np.repeat(np.arange(len(neighbors)),
                      np.fromiter(map(len, neighbors), dtype=np.int64))
        j = np.concatenate(neighbors)
        # Build integer (not float) index tensors so node ids stay exact and
        # `coalesce` receives a proper index.
        if normalization == 'col':
            edge_index = torch.from_numpy(
                np.vstack([j, i]).astype(np.int64)).to(device)
            N = len(neighbors)
            edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        elif normalization == 'row':
            edge_index = torch.from_numpy(
                np.vstack([i, j]).astype(np.int64)).to(device)
        else:
            raise ValueError(
                f"PPR matrix normalization {normalization} unknown.")
        return edge_index, edge_weight

    @staticmethod
    @jit()
    def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
        r"""Calculate the personalized PageRank vector for all nodes
        using a variant of the Andersen algorithm
        (see Andersen et al. :Local Graph Partitioning using PageRank Vectors.)

        Args:
            indptr (np.ndarray): Index pointer for the sparse matrix
                (CSR-format).
            indices (np.ndarray): Indices of the sparse matrix entries
                (CSR-format).
            out_degree (np.ndarray): Out-degree of each node.
            alpha (float): Alpha of the PageRank to calculate.
            eps (float): Threshold for PPR calculation stopping criterion
                (:obj:`edge_weight >= eps * out_degree`).

        :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`)
        """
        alpha_eps = alpha * eps
        js = []
        vals = []
        for inode in range(len(out_degree)):
            p = {inode: 0.0}   # PPR estimate
            r = {}             # residuals
            r[inode] = alpha
            q = [inode]        # nodes whose residual exceeds the threshold
            while len(q) > 0:
                unode = q.pop()

                res = r[unode] if unode in r else 0
                if unode in p:
                    p[unode] += res
                else:
                    p[unode] = res
                r[unode] = 0
                for vnode in indices[indptr[unode]:indptr[unode + 1]]:
                    _val = (1 - alpha) * res / out_degree[unode]
                    if vnode in r:
                        r[vnode] += _val
                    else:
                        r[vnode] = _val

                    res_vnode = r[vnode] if vnode in r else 0
                    if res_vnode >= alpha_eps * out_degree[vnode]:
                        if vnode not in q:
                            q.append(vnode)
            js.append(list(p.keys()))
            vals.append(list(p.values()))
        return js, vals

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


# ---------------------------------------------------------------------------
# /imports/preprocess_data.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019 Mwiza Kunda
# Copyright (C) 2017 Sarah Parisot , Sofia Ira Ktena
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import torch
import os
import warnings
import glob
import csv
import deepdish as dd
import re
import numpy as np
import scipy.io as sio
import sys
from nilearn import connectome
import pandas as pd
import os.path as osp
from functools import partial
import multiprocessing
import networkx as nx
try:
    from networkx.convert_matrix import from_numpy_matrix
except ImportError:
    # NetworkX >= 3.0 removed from_numpy_matrix; from_numpy_array is its
    # replacement for ndarray input.
    from networkx import from_numpy_array as from_numpy_matrix
from torch_sparse import coalesce
from torch_geometric.utils import remove_self_loops
from scipy.spatial import distance
from scipy import signal
from imports import utils
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings("ignore")

# Input data variables.
# NOTE(review): these are absolute paths on the author's machine; the
# commented-out lines show the repository-relative alternatives.

# root_folder = '../data'
root_folder = '/data/CodeGoat24/data'

# data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal').replace('\\', '/')
data_folder = '/data/CodeGoat24/data/ABIDE_pcp/cpac/filt_noglobal'
phenotype = '/data/CodeGoat24/data/ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv'
# phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv').replace('\\', '/')


def fetch_filenames(subject_IDs, file_type, atlas):
    """Find one file of the requested type for every subject.

    Args:
        subject_IDs: list of short subject IDs in string format.
        file_type: ``'func_preproc'`` or ``'rois_' + atlas``.
        atlas: atlas name used to build the ``'rois_<atlas>'`` file type.

    Returns:
        List of file names (same length as ``subject_IDs``), relative to the
        folder they were found in, or ``'N/A'`` where nothing matched.
        ``data_folder`` is searched first, then ``data_folder/<subject_id>``.

    Raises:
        KeyError: if ``file_type`` is not one of the supported types.

    Note:
        Changes the process working directory (``os.chdir``) while searching
        and leaves it at the last folder searched.
    """
    filemapping = {'func_preproc': '_func_preproc.nii.gz',
                   'rois_' + atlas: '_rois_' + atlas + '.1D'}
    filenames = []

    for subject_id in subject_IDs:
        os.chdir(data_folder)
        pattern = '*' + subject_id + filemapping[file_type]
        matches = glob.glob(pattern)
        if not matches:
            # Not in the top-level folder: look in the subject's own folder.
            os.chdir(data_folder + '/' + subject_id)
            matches = glob.glob(pattern)
        filenames.append(matches[0] if matches else 'N/A')
    return filenames


# Get timeseries arrays for list of subjects
def get_timeseries(subject_list, atlas_name="cc200", silence=True,
                   data_folder='/data/CodeGoat24/data/ABIDE_pcp/cpac/filt_noglobal',
                   num_timepoints=76):
    """Load the ROI time series of every subject as one float tensor.

    Args:
        subject_list: list of short subject IDs.
        atlas_name: atlas the time series were generated with, e.g. aal, cc200.
        silence: when falsy, print each file name as it is read.
        data_folder: folder holding one sub-folder per subject.
        num_timepoints: number of leading time points kept per subject
            (default 76); ``None`` keeps all of them.

    Returns:
        Float tensor of shape (subjects, regions, num_timepoints). Every
        subject must yield the same shape. An empty ``subject_list`` gives an
        empty tensor.
    """
    timeseries = []
    suffix = '_rois_' + atlas_name + '.1D'
    for subject in subject_list:
        subject_folder = os.path.join(data_folder, str(subject))
        ro_file = [f for f in os.listdir(subject_folder) if f.endswith(suffix)]
        fl = os.path.join(subject_folder, ro_file[0])
        if not silence:
            print("Reading timeseries file %s" % fl)
        # The .1D file is (timepoints x regions); transpose to (regions x timepoints).
        t = np.loadtxt(fl, skiprows=0).transpose(1, 0)
        timeseries.append(t[:, :num_timepoints])
    if not timeseries:
        return torch.zeros(0)
    return torch.from_numpy(np.stack(timeseries)).float()


def get_pcorr(dataID, raw_dir="/data/CodeGoat24/data/ABIDE_pcp/cpac/filt_noglobal/raw",
              num_nodes=200):
    """Load the absolute partial-correlation matrices of the given subjects.

    Args:
        dataID: tensor of subject IDs; ``<raw_dir>/<id>.h5`` is read for each.
        raw_dir: folder containing the ``.h5`` files.
        num_nodes: matrix side length used to reshape the result.

    Returns:
        Float tensor of shape (len(dataID), num_nodes, num_nodes) holding
        ``abs(pcorr)`` for each subject.
    """
    flat = []
    for subject_id in dataID.cpu().numpy():
        temp = dd.io.load(osp.join(raw_dir, str(subject_id) + ".h5"))
        flat.append(np.abs(temp['pcorr'][()]).ravel())
    pcorr = torch.from_numpy(np.concatenate(flat)).float()
    return pcorr.reshape(len(flat), num_nodes, num_nodes)


def compute_edge_info(pcorr):
    """Turn one dense (N x N) matrix into sparse edge attributes and indices.

    Self-loops are removed and duplicate edges coalesced.

    Returns:
        ``(edge_att, edge_index, num_nodes)``.
    """
    num_nodes = pcorr.shape[0]
    pcorr = pcorr.to_sparse()

    edge_att = pcorr.values()
    edge_index = torch.stack([pcorr.indices()[0], pcorr.indices()[1]])
    edge_index, edge_att = remove_self_loops(edge_index, edge_att)
    edge_index = edge_index.long()
    edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes,
                                    num_nodes)

    return edge_att, edge_index, num_nodes


def get_edge_index_attr(pcorr):
    """Build one batched edge list from a stack of dense matrices.

    Node indices of graph ``j`` are offset by ``j * num_nodes``, so all
    graphs share one disjoint index space.

    Returns:
        ``(edge_attr, edge_index)`` where ``edge_attr`` has shape (E, 1) and
        ``edge_index`` has shape (2, E).
    """
    edge_att_list = []
    edge_index_list = []
    for j, pco in enumerate(pcorr):
        edge_att, edge_index, num_nodes = compute_edge_info(pco)
        edge_att_list.append(edge_att)
        edge_index_list.append(edge_index + j * num_nodes)

    edge_att_arr = torch.cat(edge_att_list)
    edge_index_arr = torch.cat(edge_index_list, dim=1)

    edge_att_torch = edge_att_arr.reshape(len(edge_att_arr), 1).float()
    edge_index_torch = edge_index_arr.long()

    return edge_att_torch, edge_index_torch


# compute connectivity matrices
def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234,
                         n_subjects='', save=True, save_path=data_folder,
                         validation_ext=''):
    """Compute (and optionally save) per-subject connectivity matrices.

    Args:
        timeseries: timeseries table for subject (timepoints x regions).
        subjects: subject IDs.
        atlas_name: name of the parcellation atlas used.
        kind: ``'correlation'``, ``'partial correlation'``, ``'TE'``
            (tangent embedding of the time series) or ``'TPE'`` (tangent
            embedding of the correlation matrices).
        iter_no: tangent connectivity iteration number for cross validation
            evaluation (only used in TE/TPE file names).
        seed: seed value written into TE/TPE file names.
        n_subjects: subject count written into TE/TPE file names.
        save: save each subject's matrix to a ``.mat`` file.
        save_path: folder containing one sub-folder per subject.
        validation_ext: tag written into TE/TPE file names. Previously an
            undefined global (NameError when saving TE/TPE); defaults to ''.

    Returns:
        The connectivity matrices for ``'correlation'`` / ``'partial
        correlation'``; the fitted tangent ``ConnectivityMeasure`` for
        ``'TE'`` / ``'TPE'``.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    tangent_kinds = ['TPE', 'TE']
    if kind not in tangent_kinds + ['correlation', 'partial correlation']:
        raise ValueError('Unknown connectivity kind: {}'.format(kind))

    if kind == 'TPE':
        conn_measure = connectome.ConnectivityMeasure(kind='correlation')
        conn_mat = conn_measure.fit_transform(timeseries)
        conn_measure = connectome.ConnectivityMeasure(kind='tangent')
        connectivity_fit = conn_measure.fit(conn_mat)
        connectivity = connectivity_fit.transform(conn_mat)
    elif kind == 'TE':
        conn_measure = connectome.ConnectivityMeasure(kind='tangent')
        connectivity_fit = conn_measure.fit(timeseries)
        connectivity = connectivity_fit.transform(timeseries)
    else:
        conn_measure = connectome.ConnectivityMeasure(kind=kind)
        connectivity = conn_measure.fit_transform(timeseries)
        connectivity_fit = None

    if save:
        for i, subj_id in enumerate(subjects):
            if kind in tangent_kinds:
                file_name = (subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_'
                             + str(iter_no) + '_' + str(seed) + '_' + validation_ext
                             + str(n_subjects) + '.mat')
            else:
                file_name = subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat'
            subject_file = os.path.join(save_path, subj_id, file_name)
            sio.savemat(subject_file, {'connectivity': connectivity[i]})

    return connectivity_fit if kind in tangent_kinds else connectivity


# Get the list of subject IDs

def get_ids(num_subjects=None):
    """Read subject IDs from ``<data_folder>/subject_IDs.txt``.

    Args:
        num_subjects: if given, only the first ``num_subjects`` IDs are returned.

    Returns:
        Array of subject IDs as strings.
    """
    subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt').replace('\\', '/'), dtype=str)
    # subject_IDs = np.genfromtxt('./data/subject_ID.txt', dtype=str)

    if num_subjects is not None:
        subject_IDs = subject_IDs[:num_subjects]

    return subject_IDs


# Get phenotype values for a list of subjects
def get_subject_score(subject_list, score):
    """Read one phenotype column from the ``phenotype`` CSV for some subjects.

    Args:
        subject_list: subject IDs (strings) to keep; a row is kept when its
            ``SUB_ID`` is ``in`` this container.
        score: name of the CSV column to read.

    Returns:
        Dict mapping ``SUB_ID`` to the value. For ``HANDEDNESS_CATEGORY``,
        missing values (``-9999`` or empty) become ``'R'`` and
        ``'Mixed'`` / ``'L->R'`` become ``'Ambi'``. For ``FIQ``/``PIQ``/``VIQ``
        values are floats, with missing values replaced by 100. Any other
        column is returned as the raw string.
    """
    scores_dict = {}

    with open(phenotype) as csv_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            sub_id = row['SUB_ID']
            if sub_id not in subject_list:
                continue
            value = row[score]
            if score == 'HANDEDNESS_CATEGORY':
                if value.strip() in ('-9999', ''):
                    scores_dict[sub_id] = 'R'
                elif value in ('Mixed', 'L->R'):
                    scores_dict[sub_id] = 'Ambi'
                else:
                    scores_dict[sub_id] = value
            elif score in ('FIQ', 'PIQ', 'VIQ'):
                if value.strip() in ('-9999', ''):
                    scores_dict[sub_id] = 100
                else:
                    scores_dict[sub_id] = float(value)
            else:
                scores_dict[sub_id] = value

    return scores_dict


# preprocess phenotypes.
Categorical -> ordinal representation 245 | def preprocess_phenotypes(pheno_ft, params): 246 | if params['model'] == 'MIDA': 247 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough') 248 | else: 249 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough') 250 | 251 | pheno_ft = ct.fit_transform(pheno_ft) 252 | pheno_ft = pheno_ft.astype('float32') 253 | 254 | return (pheno_ft) 255 | 256 | 257 | # create phenotype feature vector to concatenate with fmri feature vectors 258 | def phenotype_ft_vector(pheno_ft, num_subjects, params): 259 | gender = pheno_ft[:, 0] 260 | if params['model'] == 'MIDA': 261 | eye = pheno_ft[:, 0] 262 | hand = pheno_ft[:, 2] 263 | age = pheno_ft[:, 3] 264 | fiq = pheno_ft[:, 4] 265 | else: 266 | eye = pheno_ft[:, 2] 267 | hand = pheno_ft[:, 3] 268 | age = pheno_ft[:, 4] 269 | fiq = pheno_ft[:, 5] 270 | 271 | phenotype_ft = np.zeros((num_subjects, 4)) 272 | phenotype_ft_eye = np.zeros((num_subjects, 2)) 273 | phenotype_ft_hand = np.zeros((num_subjects, 3)) 274 | 275 | for i in range(num_subjects): 276 | phenotype_ft[i, int(gender[i])] = 1 277 | phenotype_ft[i, -2] = age[i] 278 | phenotype_ft[i, -1] = fiq[i] 279 | phenotype_ft_eye[i, int(eye[i])] = 1 280 | phenotype_ft_hand[i, int(hand[i])] = 1 281 | 282 | if params['model'] == 'MIDA': 283 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) 284 | else: 285 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1) 286 | 287 | return phenotype_ft 288 | 289 | 290 | # Load precomputed fMRI connectivity networks 291 | def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal", 292 | variable='connectivity'): 293 | """ 294 | subject_list : list of subject IDs 295 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 296 | atlas_name : name of the parcellation atlas used 297 | variable : variable name in the .mat file that has been used to save the precomputed networks 298 | return: 299 | matrix : feature matrix of connectivity networks (num_subjects x network_size) 300 | """ 301 | 302 | all_networks = [] 303 | for subject in subject_list: 304 | if len(kind.split()) == 2: 305 | kind = '_'.join(kind.split()) 306 | fl = os.path.join(data_folder, subject, 307 | subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat") 308 | 309 | 310 | matrix = sio.loadmat(fl)[variable] 311 | all_networks.append(matrix) 312 | 313 | if kind in ['TE', 'TPE']: 314 | norm_networks = [mat for mat in all_networks] 315 | else: 316 | norm_networks = [np.arctanh(mat) for mat in all_networks] 317 | 318 | networks = np.stack(norm_networks) 319 | 320 | return networks 321 | 322 | -------------------------------------------------------------------------------- /imports/read_abide_stats_parall.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Xiaoxiao Li 3 | Date: 2019/02/24 4 | ''' 5 | 6 | import os.path as osp 7 | from os import listdir 8 | import os 9 | import glob 10 | import h5py 11 | 12 | import torch 13 | import numpy as np 14 | from scipy.io import loadmat 15 | from torch_geometric.data import Data 16 | import networkx as nx 17 | from networkx.convert_matrix import from_numpy_matrix 18 | import multiprocessing 19 | from torch_sparse import coalesce 20 | from torch_geometric.utils import remove_self_loops 21 | from functools import partial 22 | import deepdish as dd 23 | from imports.gdc import GDC 24 | 25 | 26 | 27 | 28 | def split(data, batch): 29 | node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0) 30 | node_slice = torch.cat([torch.tensor([0]), node_slice]) 31 | 32 | row, _ = data.edge_index 33 | edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0) 34 | 
edge_slice = torch.cat([torch.tensor([0]), edge_slice]) 35 | 36 | # Edge indices should start at zero for every graph. 37 | data.edge_index -= node_slice[batch[row]].unsqueeze(0) 38 | 39 | slices = {'edge_index': edge_slice} 40 | if data.x is not None: 41 | slices['x'] = node_slice 42 | if data.edge_attr is not None: 43 | slices['edge_attr'] = edge_slice 44 | if data.y is not None: 45 | if data.y.size(0) == batch.size(0): 46 | slices['y'] = node_slice 47 | else: 48 | slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long) 49 | if data.pos is not None: 50 | slices['pos'] = node_slice 51 | 52 | return data, slices 53 | 54 | 55 | def cat(seq): 56 | seq = [item for item in seq if item is not None] 57 | seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq] 58 | return torch.cat(seq, dim=-1).squeeze() if len(seq) > 0 else None 59 | 60 | class NoDaemonProcess(multiprocessing.Process): 61 | @property 62 | def daemon(self): 63 | return False 64 | 65 | @daemon.setter 66 | def daemon(self, value): 67 | pass 68 | 69 | 70 | class NoDaemonContext(type(multiprocessing.get_context())): 71 | Process = NoDaemonProcess 72 | 73 | 74 | def read_data(data_dir): 75 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] 76 | onlyfiles.sort() 77 | batch = [] 78 | pseudo = [] 79 | y_list = [] 80 | edge_att_list, edge_index_list,att_list = [], [], [] 81 | 82 | 83 | # parallar computing 84 | cores = multiprocessing.cpu_count() 85 | pool = multiprocessing.Pool(processes=cores) 86 | #pool = MyPool(processes = cores) 87 | func = partial(read_sigle_data, data_dir) 88 | 89 | import timeit 90 | 91 | start = timeit.default_timer() 92 | 93 | res = pool.map(func, onlyfiles) 94 | 95 | pool.close() 96 | pool.join() 97 | 98 | stop = timeit.default_timer() 99 | 100 | print('Time: ', stop - start) 101 | 102 | 103 | 104 | for j in range(len(res)): 105 | edge_att_list.append(res[j][0]) 106 | edge_index_list.append(res[j][1]+j*res[j][4]) 107 | 
att_list.append(res[j][2]) 108 | y_list.append(res[j][3]) 109 | batch.append([j]*res[j][4]) 110 | pseudo.append(np.diag(np.ones(res[j][4]))) 111 | 112 | 113 | edge_att_arr = np.concatenate(edge_att_list) 114 | edge_index_arr = np.concatenate(edge_index_list, axis=1) 115 | att_arr = np.concatenate(att_list, axis=0) 116 | pseudo_arr = np.concatenate(pseudo, axis=0) 117 | y_arr = np.stack(y_list) 118 | 119 | 120 | edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float() 121 | att_torch = torch.from_numpy(att_arr).float() 122 | y_torch = torch.from_numpy(y_arr).long() # classification 123 | batch_torch = torch.from_numpy(np.hstack(batch)).long() 124 | edge_index_torch = torch.from_numpy(edge_index_arr).long() 125 | pseudo_torch = torch.from_numpy(pseudo_arr).float() 126 | 127 | 128 | data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos=pseudo_torch) 129 | # data.pcorr = pcorr_torch 130 | 131 | data, slices = split(data, batch_torch) 132 | 133 | return data, slices 134 | 135 | 136 | def read_sigle_data(data_dir,filename,use_gdc =False): 137 | 138 | temp = dd.io.load(osp.join(data_dir, filename)) 139 | 140 | 141 | # # 获取时间序列 142 | # timeseries = get_timeseries_by_ID(filename[:5]) 143 | 144 | # read edge and edge attribute 145 | pcorr = np.abs(temp['pcorr'][()]) 146 | 147 | num_nodes = pcorr.shape[0] 148 | G = from_numpy_matrix(pcorr) 149 | A = nx.to_scipy_sparse_matrix(G) 150 | adj = A.tocoo() 151 | edge_att = np.zeros(len(adj.row)) 152 | for i in range(len(adj.row)): 153 | edge_att[i] = pcorr[adj.row[i], adj.col[i]] 154 | 155 | edge_index = np.stack([adj.row, adj.col]) 156 | edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att)) 157 | edge_index = edge_index.long() 158 | edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, 159 | num_nodes) 160 | att = temp['corr'][()] 161 | label = temp['label'][()] 162 | 163 | att_torch = 
torch.from_numpy(att).float() 164 | y_torch = torch.from_numpy(np.array(label)).long() # classification 165 | 166 | data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att) 167 | 168 | if use_gdc: 169 | ''' 170 | Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html 171 | ''' 172 | data.edge_attr = data.edge_attr.squeeze() 173 | gdc = GDC(self_loop_weight=1, normalization_in='sym', 174 | normalization_out='col', 175 | diffusion_kwargs=dict(method='ppr', alpha=0.2), 176 | sparsification_kwargs=dict(method='topk', k=20, 177 | dim=0), exact=True) 178 | data = gdc(data) 179 | return data.edge_attr.data.numpy(),data.edge_index.data.numpy(),data.x.data.numpy(),data.y.data.item(),num_nodes 180 | 181 | else: 182 | # 增加了时间序列 183 | label = np.append(label,int(filename[:5])) 184 | return edge_att.data.numpy(),edge_index.data.numpy(),att,label,num_nodes 185 | 186 | 187 | 188 | if __name__ == "__main__": 189 | data_dir = '/home/azureuser/projects/net/data/ABIDE_pcp/cpac/filt_noglobal/raw' 190 | filename = '50346.h5' 191 | read_sigle_data(data_dir, filename) 192 | 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /imports/utils.py: -------------------------------------------------------------------------------- 1 | from scipy import stats 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import torch 6 | from scipy.io import loadmat 7 | from sklearn.model_selection import StratifiedKFold 8 | from sklearn.model_selection import KFold 9 | 10 | 11 | def train_val_test_split(kfold = 5, fold = 0): 12 | n_sub = 1035 13 | id = list(range(n_sub)) 14 | 15 | 16 | import random 17 | random.seed(123) 18 | random.shuffle(id) 19 | 20 | kf = KFold(n_splits=kfold, random_state=123,shuffle = True) 21 | kf2 = KFold(n_splits=kfold-1, shuffle=True, random_state = 666) 22 | 23 | 24 | test_index = list() 25 | train_index = list() 
26 | val_index = list() 27 | 28 | for tr,te in kf.split(np.array(id)): 29 | test_index.append(te) 30 | tr_id, val_id = list(kf2.split(tr))[0] 31 | train_index.append(tr[tr_id]) 32 | val_index.append(tr[val_id]) 33 | 34 | train_id = train_index[fold] 35 | test_id = test_index[fold] 36 | val_id = val_index[fold] 37 | 38 | return train_id,val_id,test_id 39 | 40 | def get_project_path(): 41 | """得到项目路径""" 42 | project_path = os.path.join( 43 | os.path.dirname(__file__), 44 | "..", 45 | ) 46 | return project_path 47 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import yaml 4 | import torch 5 | import os 6 | import random 7 | import numpy as np 8 | import torch.backends.cudnn as cudnn 9 | 10 | 11 | from train import BasicTrain 12 | 13 | from model.model import PLSNet 14 | from dataloader import init_dataloader 15 | 16 | 17 | 18 | 19 | def main(args): 20 | with open(args.config_filename) as f: 21 | config = yaml.load(f, Loader=yaml.Loader) 22 | 23 | dataloaders, node_size, node_feature_size, timeseries_size = \ 24 | init_dataloader(config['data']) 25 | 26 | config['train']["seq_len"] = timeseries_size 27 | config['train']["node_size"] = node_size 28 | 29 | 30 | 31 | model = PLSNet(config['model'], node_size, 32 | node_feature_size, timeseries_size) 33 | use_train = BasicTrain 34 | 35 | 36 | 37 | 38 | optimizer = torch.optim.Adam( 39 | model.parameters(), lr=config['train']['lr'], 40 | weight_decay=config['train']['weight_decay']) 41 | opts = (optimizer,) 42 | 43 | loss_name = 'loss' 44 | if config['train']["group_loss"]: 45 | loss_name = f"{loss_name}_group_loss" 46 | if config['train']["sparsity_loss"]: 47 | loss_name = f"{loss_name}_sparsity_loss" 48 | 49 | 50 | save_folder_name = Path(config['train']['log_folder'])/Path(config['model']['type'])/Path( 51 | # date_time + 52 | 
f"{config['data']['dataset']}_{config['data']['atlas']}") 53 | 54 | train_process = use_train( 55 | config['train'], model, opts, dataloaders, save_folder_name) 56 | 57 | train_process.train() 58 | 59 | 60 | if __name__ == '__main__': 61 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--config_filename', default='setting/abide_PLSNet.yaml', type=str, 64 | help='Configuration filename for training the model.') 65 | parser.add_argument('--repeat_time', default=100, type=int) 66 | args = parser.parse_args() 67 | torch.cuda.set_device(0) 68 | # 控制随机性 69 | seed = 0 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | if torch.cuda.is_available(): 73 | torch.cuda.manual_seed_all(seed) 74 | 75 | torch.manual_seed(seed) 76 | torch.cuda.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | 79 | # cudnn.benchmark = False 80 | cudnn.deterministic = True 81 | 82 | for i in range(args.repeat_time): 83 | main(args) 84 | -------------------------------------------------------------------------------- /model/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | import torch 3 | 4 | 5 | class FullyConnectedOutput(torch.nn.Module): 6 | def __init__(self, embed_dim, input_dim): 7 | super().__init__() 8 | self.fc = torch.nn.Sequential( 9 | torch.nn.Linear(embed_dim, 32), 10 | torch.nn.LeakyReLU(negative_slope=0.2), 11 | torch.nn.Dropout(p=0.1), 12 | torch.nn.Linear(32, embed_dim), 13 | torch.nn.LeakyReLU(negative_slope=0.2), 14 | torch.nn.Dropout(p=0.1) 15 | ) 16 | 17 | self.norm = torch.nn.LayerNorm(normalized_shape=embed_dim, elementwise_affine=True) 18 | 19 | def forward(self, x): 20 | 21 | x = self.norm(x) 22 | 23 | 24 | # [b, 50, 32] -> [b, 50, 32] 25 | out = self.fc(x) 26 | 27 | return out 28 | 29 | 30 | 31 | def attention(Q, K, V): 32 | 33 | l = Q.shape[2] 34 | num_head = Q.shape[1] 35 | 36 | 37 | score = torch.matmul(Q, K.permute(0, 1, 3, 2)) 38 | 
39 | 40 | score /= 8 ** 0.5 41 | 42 | score = torch.softmax(score, dim=-1) 43 | 44 | 45 | 46 | score = torch.matmul(score, V) 47 | 48 | score = score.permute(0, 2, 1, 3).reshape(-1, l, num_head * Q.shape[3]) 49 | 50 | return score 51 | 52 | 53 | class MultiHead(torch.nn.Module): 54 | def __init__(self, input_dim, num_head, embed_dim): 55 | super().__init__() 56 | self.fc_Q = torch.nn.Linear(input_dim, 32) 57 | self.fc_K = torch.nn.Linear(input_dim, 32) 58 | self.fc_V = torch.nn.Linear(input_dim, 32) 59 | 60 | self.num_head = num_head 61 | 62 | self.out_fc = torch.nn.Linear(32, embed_dim) 63 | 64 | self.norm = torch.nn.LayerNorm(normalized_shape=input_dim, elementwise_affine=True) 65 | self.dropout = torch.nn.Dropout(p=0.1) 66 | 67 | def forward(self, Q, K, V): 68 | 69 | # Q, K, V = [b, 50, 32] 70 | b = Q.shape[0] 71 | len = Q.shape[1] 72 | 73 | Q = self.norm(Q) 74 | K = self.norm(K) 75 | V = self.norm(V) 76 | 77 | K = self.fc_K(K) 78 | V = self.fc_V(V) 79 | Q = self.fc_Q(Q) 80 | 81 | 82 | Q = Q.reshape(b, len, self.num_head, -1).permute(0, 2, 1, 3) 83 | K = K.reshape(b, len, self.num_head, -1).permute(0, 2, 1, 3) 84 | V = V.reshape(b, len, self.num_head, -1).permute(0, 2, 1, 3) 85 | 86 | score = attention(Q, K, V) 87 | 88 | score = self.dropout(self.out_fc(score)) 89 | 90 | return score 91 | 92 | 93 | class EncoderLayer(torch.nn.Module): 94 | def __init__(self, input_dim, num_head, embed_dim): 95 | super(EncoderLayer, self).__init__() 96 | self.mh = MultiHead(input_dim, num_head, embed_dim) 97 | self.fc = FullyConnectedOutput(embed_dim, input_dim) 98 | 99 | def forward(self, x): 100 | score = self.mh(x, x, x) 101 | out = self.fc(score) 102 | 103 | return out 104 | 105 | 106 | class Encoder(torch.nn.Module): 107 | def __init__(self, input_dim, num_head, embed_dim): 108 | super(Encoder, self).__init__() 109 | self.layer = EncoderLayer(input_dim, num_head, embed_dim) 110 | 111 | def forward(self, x): 112 | x = self.layer(x) 113 | 114 | return x 115 | 
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .GSL import BrainGSLModel, TSConstruction 2 | from .model import FBNETGEN, GNNPredictor, SeqenceModel, BrainNetCNN -------------------------------------------------------------------------------- /model/__pycache__/Encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/model/__pycache__/Encoder.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/GAU.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/model/__pycache__/GAU.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import Conv1d, MaxPool1d, Linear, GRU 6 | 7 | 8 | device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | from model.Encoder import Encoder 10 | 11 | 12 | class Embed2GraphByProduct(nn.Module): 13 | 14 | def __init__(self, input_dim, roi_num=264): 15 | super().__init__() 16 | 17 | def forward(self, x): 18 | 19 | m = torch.einsum('ijk,ipk->ijp', x, x) 20 | 21 | m = torch.unsqueeze(m, -1) 22 | 23 | return m 24 | 25 | class GCNPredictor(nn.Module): 26 | 27 | def __init__(self, node_input_dim, roi_num=360): 28 | super().__init__() 29 | inner_dim = roi_num 30 | self.roi_num = roi_num 31 | self.gcn = nn.Sequential( 32 | nn.Linear(node_input_dim, inner_dim), 33 | nn.LeakyReLU(negative_slope=0.2), 34 | Linear(inner_dim, inner_dim) 35 | ) 36 | self.bn1 = torch.nn.BatchNorm1d(inner_dim) 37 | 38 | self.gcn1 = nn.Sequential( 39 | nn.Linear(inner_dim, inner_dim), 40 | nn.LeakyReLU(negative_slope=0.2), 41 | ) 42 | self.bn2 = torch.nn.BatchNorm1d(inner_dim) 43 | self.gcn2 = nn.Sequential( 44 | nn.Linear(inner_dim, 64), 45 | nn.LeakyReLU(negative_slope=0.2), 46 | nn.Linear(64, 8), 47 | nn.LeakyReLU(negative_slope=0.2), 48 | ) 49 | self.bn3 = torch.nn.BatchNorm1d(inner_dim) 50 | 51 | 52 | self.fcn = nn.Sequential( 53 | nn.Linear(int(8 * int(roi_num * 0.7)), 256), 54 | # nn.Linear(8 * roi_num, 256), 55 | nn.LeakyReLU(negative_slope=0.2), 56 | nn.Linear(256, 32), 57 | nn.LeakyReLU(negative_slope=0.2), 58 | nn.Linear(32, 2) 59 | ) 60 | self.norm = torch.nn.LayerNorm(normalized_shape=roi_num, elementwise_affine=True) 61 | self.weight = torch.nn.Parameter(torch.Tensor(1, 8)) 62 | 63 | self.softmax = nn.Sigmoid() 64 | 65 | 66 | def forward(self, m, node_feature): 67 | bz = m.shape[0] 68 | 69 | x = torch.einsum('ijk,ijp->ijp', m, node_feature) 70 | 71 | x = self.gcn(x) 72 | 73 | x = x.reshape((bz*self.roi_num, -1)) 74 | x = self.bn1(x) 75 | x = x.reshape((bz, self.roi_num, -1)) 76 | 77 | x = torch.einsum('ijk,ijp->ijp', m, x) 78 | 79 | x = self.gcn1(x) 80 | 81 | x = x.reshape((bz*self.roi_num, -1)) 82 | x = self.bn2(x) 83 | x = 
x.reshape((bz, self.roi_num, -1)) 84 | 85 | x = torch.einsum('ijk,ijp->ijp', m, x) 86 | 87 | x = self.gcn2(x) 88 | 89 | x = self.bn3(x) 90 | 91 | score = (x * self.weight).sum(dim=-1) 92 | 93 | # score = self.norm(score) 94 | score = self.softmax(score) 95 | sc = score 96 | 97 | _, idx = score.sort(dim=-1) 98 | _, rank = idx.sort(dim=-1) 99 | 100 | l = int(m.shape[1] * 0.7) 101 | x_p = torch.empty(bz, l, 8) 102 | 103 | for i in range(x.shape[0]): 104 | x_p[i] = x[i, rank[i, :l], :] 105 | 106 | 107 | x = x_p.view(bz,-1).to(device) 108 | # x = x.view(bz,-1).to(device) 109 | 110 | 111 | # return self.fcn(x), x 112 | return self.fcn(x), sc 113 | 114 | 115 | class PLSNet(nn.Module): 116 | 117 | def __init__(self, model_config, roi_num=360, node_feature_dim=360, time_series=512): 118 | super().__init__() 119 | 120 | self.extract = Encoder(input_dim=time_series, num_head=4, embed_dim=model_config['embedding_size']) 121 | 122 | 123 | 124 | 125 | self.emb2graph = Embed2GraphByProduct( 126 | model_config['embedding_size'], roi_num=roi_num) 127 | 128 | self.predictor = GCNPredictor(node_feature_dim, roi_num=roi_num) 129 | self.fc_q = nn.Sequential(nn.Linear(in_features=model_config['embedding_size'], out_features=roi_num), 130 | nn.LeakyReLU(negative_slope=0.2)) 131 | 132 | 133 | self.fc_p = nn.Sequential(nn.Linear(in_features=roi_num, out_features=roi_num), 134 | nn.LeakyReLU(negative_slope=0.2)) 135 | 136 | def forward(self, t, nodes, pseudo): 137 | x = self.extract(t) 138 | m = F.softmax(x, dim=-1) 139 | m = self.emb2graph(m) 140 | 141 | m = m[:, :, :, 0] 142 | 143 | bz, _, _ = m.shape 144 | 145 | edge_variance = torch.mean(torch.var(m.reshape((bz, -1)), dim=1)) 146 | 147 | pseudo = self.fc_p(pseudo) 148 | nodes = nodes + pseudo 149 | return self.predictor(m, nodes), m, edge_variance 150 | 151 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |
2 |

PLSNet: Position-aware GCN-based autism spectrum disorder diagnosis via FC learning and ROIs sifting

3 | 4 | 5 | [Yibin Wang](https://codegoat24.github.io)\*, Haixia Long, Qianwei Zhou, Tao Bo, [Jianwei Zheng](https://zhengjianwei2.github.io/)† 6 | 7 | (†corresponding author) 8 | 9 | [Zhejiang University of Technology] 10 | 11 | Accepted by _**Computers in Biology and Medicine**_ 12 | 13 |
14 | 15 | ## 📖 Abstract 16 | Brain functional connectivity, derived from functional magnetic resonance imaging (fMRI), has enjoyed high popularity in studies of Autism Spectrum Disorder (ASD) diagnosis. Although rapid progress has been made, most studies still suffer from several knotty issues: (1) the difficulty of modeling sophisticated brain neuronal connectivity; (2) the mismatch between an identical graph-node setup and the variations across different brain regions; (3) the dimensionality explosion resulting from the excessive number of voxels in each fMRI sample; (4) poor interpretability, giving rise to unpersuasive diagnoses. To ameliorate these issues, we propose a position-aware graph-convolution-network-based model, namely PLSNet, with superior accuracy and compelling built-in interpretability for ASD diagnosis. Specifically, a time-series encoder is designed for context-rich feature extraction, followed by a functional connectivity generator that models correlations with long-range dependencies. In addition, to discriminate brain nodes at different locations, a position embedding technique is adopted, giving a unique identity to each graph region. We then embed a rarefying method to sift the salient nodes during message diffusion, which also reduces the dimensionality complexity. Extensive experiments conducted on the Autism Brain Imaging Data Exchange (ABIDE) demonstrate that PLSNet achieves state-of-the-art performance. Notably, on the CC200 atlas, PLSNet reaches an accuracy of 76.4% and a specificity of 78.6%, surpassing the previous state of the art by 2.5% and 6.5%, respectively, under a five-fold cross-validation policy. Moreover, the most salient brain regions predicted by PLSNet are closely consistent with theoretical knowledge in the medical domain, providing potential biomarkers for ASD clinical diagnosis.
17 | 18 | ![PLSNet](./PLSNet.png) 19 | 20 | ![Biomarkers](./Biomarkers.png) 21 | 22 | ## 🔧 Dataset 23 | 24 | 25 | Please follow the [instructions](util/abide/readme.md) to download and process the **ABIDE** dataset. 26 | 27 | ## 🔥 Run 28 | 29 | ```bash 30 | python main.py --config_filename setting/abide_PLSNet.yaml 31 | ``` 32 | 33 | ### Hyperparameters 34 | 35 | All hyperparameters can be tuned in the setting files. 36 | 37 | ```yaml 38 | model: 39 | type: PLSNet 40 | extractor_type: attention 41 | embedding_size: 8 42 | window_size: 4 43 | 44 | dropout: 0.5 45 | 46 | 47 | 48 | train: 49 | lr: 1.0e-4 50 | weight_decay: 1.0e-4 51 | epochs: 500 52 | pool_ratio: 0.7 53 | optimizer: adam 54 | stepsize: 200 55 | 56 | group_loss: true 57 | sparsity_loss: true 58 | sparsity_loss_weight: 0.5e-4 59 | log_folder: result 60 | 61 | # uniform or pearson 62 | pure_gnn_graph: pearson 63 | ``` 64 | 65 | ## ⏬ Download the Pre-trained Models 66 | We provide models for PLSNet_AAL and PLSNet_CC200. 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 |
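|   | atlas | acc. % | sen. % | spe. % | url of model |
|---|-------|--------|--------|--------|--------------|
| 0 | AAL   | 72.4   | 71.6   | 71.3   | baidu disk (code: 7fig) |
| 1 | CC200 | 76.4   | 74.7   | 78.6   | baidu disk (code: pmbz) |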
99 | 100 | 101 | ## 🖊️ BibTeX 102 | If you use this repository in your research, consider citing it using the following Bibtex entry: 103 | 104 | ``` 105 | @article{wang2023plsnet, 106 | title={PLSNet: Position-aware GCN-based autism spectrum disorder diagnosis via FC learning and ROIs sifting}, 107 | author={Wang, Yibin and Long, Haixia and Zhou, Qianwei and Bo, Tao and Zheng, Jianwei}, 108 | journal={Computers in Biology and Medicine}, 109 | pages={107184}, 110 | year={2023}, 111 | volume={163}, 112 | publisher={Elsevier} 113 | } 114 | ``` 115 | 116 | ## 📧 Contact 117 | 118 | If you have any technical comments or questions, please open a new issue or feel free to contact [Yibin Wang](https://codegoat24.github.io). 119 | -------------------------------------------------------------------------------- /setting/abide_PLSNet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: ABIDE 3 | atlas: aal 4 | batch_size: 16 5 | test_batch_size: 16 6 | val_batch_size: 16 7 | train_set: 0.7 8 | val_set: 0.1 9 | fold: 0 10 | # time_seires: /data/CodeGoat24/FBNETGEN/ABIDE_pcp/abide.npy 11 | # time_seires: /data/CodeGoat24/FBNETGEN_ho/ABIDE_pcp/abide.npy 12 | time_seires: /data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy 13 | # time_seires: /data/CodeGoat24/FBNETGEN_cc400/ABIDE_pcp/abide.npy 14 | 15 | 16 | 17 | 18 | model: 19 | type: PLSNet 20 | extractor_type: attention 21 | embedding_size: 8 22 | window_size: 4 23 | 24 | cnn_pool_size: 16 25 | 26 | num_gru_layers: 4 27 | 28 | dropout: 0.5 29 | 30 | 31 | 32 | train: 33 | lr: 1.0e-4 34 | weight_decay: 1.0e-4 35 | epochs: 500 36 | pool_ratio: 0.7 37 | optimizer: adam 38 | stepsize: 200 39 | 40 | group_loss: true 41 | sparsity_loss: true 42 | sparsity_loss_weight: 0.5e-4 43 | log_folder: result 44 | 45 | # uniform or pearson 46 | pure_gnn_graph: pearson -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from datetime import datetime 6 | from util import Logger, accuracy, TotalMeter 7 | import numpy as np 8 | from pathlib import Path 9 | import torch.nn.functional as F 10 | 11 | from sklearn.metrics import roc_auc_score, confusion_matrix 12 | from sklearn.metrics import precision_recall_fscore_support 13 | from util.prepossess import mixup_criterion, mixup_data 14 | from util.loss import mixup_cluster_loss, topk_loss 15 | 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 19 | 20 | 21 | class BasicTrain: 22 | 23 | def __init__(self, train_config, model, optimizers, dataloaders, log_folder) -> None: 24 | self.logger = Logger() 25 | self.model = model.to(device) 26 | self.train_dataloader, self.val_dataloader, self.test_dataloader = dataloaders 27 | self.epochs = train_config['epochs'] 28 | self.optimizers = optimizers 29 | self.best_acc = 0 30 | self.best_model = None 31 | self.best_acc_val = 0 32 | self.best_auc_val = 0 33 | self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean') 34 | self.pool_ratio = train_config['pool_ratio'] 35 | self.group_loss = train_config['group_loss'] 36 | 37 | self.sparsity_loss = train_config['sparsity_loss'] 38 | self.sparsity_loss_weight = train_config['sparsity_loss_weight'] 39 | 40 | self.save_path = log_folder 41 | 42 | self.save_learnable_graph = True 43 | 44 | self.init_meters() 45 | 46 | def init_meters(self): 47 | self.train_loss, self.val_loss, self.test_loss, self.train_accuracy,\ 48 | self.val_accuracy, self.test_accuracy, self.edges_num = [ 49 | TotalMeter() for _ in range(7)] 50 | 51 | self.loss1, self.loss2, self.loss3 = [TotalMeter() for _ in range(3)] 52 | 53 | def reset_meters(self): 54 | for meter in [self.train_accuracy, self.val_accuracy, self.test_accuracy, 55 | self.train_loss, self.val_loss, self.test_loss, 
self.edges_num, 56 | self.loss1, self.loss2, self.loss3]: 57 | meter.reset() 58 | 59 | def train_per_epoch(self, optimizer): 60 | 61 | 62 | self.model.train() 63 | for data_in, pearson, label, pseudo in self.train_dataloader: 64 | 65 | label = label.long() 66 | 67 | data_in, pearson, label, pseudo = data_in.to( 68 | device), pearson.to(device), label.to(device), pseudo.to(device) 69 | 70 | inputs, nodes, targets_a, targets_b, lam = mixup_data( 71 | data_in, pearson, label, 1, device) 72 | 73 | [output, score], learnable_matrix, edge_variance = self.model(inputs, nodes, pseudo) 74 | 75 | loss = 2 * mixup_criterion( 76 | self.loss_fn, output, targets_a, targets_b, lam) 77 | 78 | if self.group_loss: 79 | loss += mixup_cluster_loss(learnable_matrix, 80 | targets_a, targets_b, lam) 81 | 82 | # loss += 0.001*topk_loss(score, self.pool_ratio) 83 | 84 | self.train_loss.update_with_weight(loss.item(), label.shape[0]) 85 | optimizer.zero_grad() 86 | loss.backward() 87 | optimizer.step() 88 | top1 = accuracy(output, label)[0] 89 | self.train_accuracy.update_with_weight(top1, label.shape[0]) 90 | self.edges_num.update_with_weight(edge_variance, label.shape[0]) 91 | 92 | def test_per_epoch(self, dataloader, loss_meter, acc_meter): 93 | labels = [] 94 | result = [] 95 | 96 | self.model.eval() 97 | 98 | for data_in, pearson, label, pseudo in dataloader: 99 | label = label.long() 100 | data_in, pearson, label, pseudo = data_in.to( 101 | device), pearson.to(device), label.to(device), pseudo.to(device) 102 | [output, score], _, _ = self.model(data_in, pearson, pseudo) 103 | loss = self.loss_fn(output, label) 104 | loss_meter.update_with_weight( 105 | loss.item(), label.shape[0]) 106 | top1 = accuracy(output, label)[0] 107 | acc_meter.update_with_weight(top1, label.shape[0]) 108 | result += F.softmax(output, dim=1)[:, 1].tolist() 109 | labels += label.tolist() 110 | 111 | auc = roc_auc_score(labels, result) 112 | 113 | result = np.array(result) 114 | result[result > 0.5] = 1 115 | 
result[result <= 0.5] = 0 116 | metric = precision_recall_fscore_support( 117 | labels, result, average='micro') 118 | con_matrix = confusion_matrix(labels, result) 119 | return [auc] + list(metric), con_matrix 120 | 121 | def generate_save_learnable_matrix(self): 122 | learable_matrixs = [] 123 | 124 | labels = [] 125 | 126 | for data_in, nodes, label, pseudo in self.test_dataloader: 127 | label = label.long() 128 | data_in, nodes, label, pseudo = data_in.to( 129 | device), nodes.to(device), label.to(device), pseudo.to(device) 130 | _, learable_matrix, _ = self.model(data_in, nodes, pseudo) 131 | 132 | learable_matrixs.append(learable_matrix.cpu().detach().numpy()) 133 | labels += label.tolist() 134 | 135 | self.save_path.mkdir(exist_ok=True, parents=True) 136 | np.save(self.save_path/"learnable_matrix.npy", {'matrix': np.vstack( 137 | learable_matrixs), "label": np.array(labels)}, allow_pickle=True) 138 | 139 | def save_result(self, results, txt): 140 | 141 | self.save_path.mkdir(exist_ok=True, parents=True) 142 | np.save(self.save_path/"training_process.npy", 143 | results, allow_pickle=True) 144 | with open(self.save_path / "training_info.txt", 'a', encoding='utf-8') as f: 145 | f.write(txt) 146 | torch.save(self.best_model.state_dict(), self.save_path/f"model_{self.best_acc}%.pt") 147 | 148 | 149 | 150 | def train(self): 151 | training_process = [] 152 | txt = '' 153 | for epoch in range(self.epochs): 154 | self.reset_meters() 155 | self.train_per_epoch(self.optimizers[0]) 156 | val_result, _ = self.test_per_epoch(self.val_dataloader, 157 | self.val_loss, self.val_accuracy) 158 | 159 | test_result, con_matrix = self.test_per_epoch(self.test_dataloader, 160 | self.test_loss, self.test_accuracy) 161 | 162 | if self.best_acc <= self.test_accuracy.avg: 163 | self.best_acc = self.test_accuracy.avg 164 | self.best_model = self.model 165 | 166 | if (con_matrix[0][0] + con_matrix[1][0]) != 0: 167 | SEN = con_matrix[0][0] / (con_matrix[0][0] + con_matrix[1][0]) 168 | 
else: 169 | SEN = 0 170 | 171 | if (con_matrix[1][1] + con_matrix[0][1]) != 0: 172 | SPE = con_matrix[1][1] / (con_matrix[1][1] + con_matrix[0][1]) 173 | else: 174 | SPE = 0 175 | 176 | self.logger.info(" | ".join([ 177 | f'Epoch[{epoch}/{self.epochs}]', 178 | f'Train Loss:{self.train_loss.avg: .3f}', 179 | f'Train Accuracy:{self.train_accuracy.avg: .3f}%', 180 | # f'Edges:{self.edges_num.avg: .3f}', 181 | # f'Test Loss:{self.test_loss.avg: .3f}', 182 | f'Val Accuracy:{self.val_accuracy.avg: .3f}%', 183 | f'Test Accuracy:{self.test_accuracy.avg: .3f}%', 184 | f'Val AUC:{val_result[0]:.2f}', 185 | f'Test AUC:{test_result[0]:.4f}', 186 | f'Test SEN:{SEN:.4f}', 187 | f'Test SPE:{SPE:.4f}' 188 | ])) 189 | 190 | txt += f'Epoch[{epoch}/{self.epochs}] '+f'Train Loss:{self.train_loss.avg: .3f} '+f'Train Accuracy:{self.train_accuracy.avg: .3f}% '+f'Val Accuracy:{self.val_accuracy.avg: .3f}% '+f'Test Accuracy:{self.test_accuracy.avg: .3f}% '+f'Val AUC:{val_result[0]:.3f} '+f'Test AUC:{test_result[0]:.4f}'+f'Test SEN:{SEN:.4f}'+f'Test SPE:{SPE:.4f}'+'\n' 191 | 192 | training_process.append([self.train_accuracy.avg, self.train_loss.avg, 193 | self.val_loss.avg, self.test_loss.avg] 194 | + val_result + test_result) 195 | now = datetime.now() 196 | date_time = now.strftime("%m-%d-%H-%M-%S") 197 | self.save_path = self.save_path/Path(f"{self.best_acc: .3f}%_{date_time}") 198 | if self.save_learnable_graph: 199 | self.generate_save_learnable_matrix() 200 | self.save_result(training_process, txt) 201 | 202 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .meter import AverageMeter, TotalMeter, accuracy 3 | -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/__pycache__/meter.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/prepossess.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/__pycache__/prepossess.cpython-37.pyc -------------------------------------------------------------------------------- /util/abide/01-fetch_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot , , Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # 
(at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | ''' 18 | This script mainly refers to https://github.com/kundaMwiza/fMRI-site-adaptation/blob/master/fetch_data.py 19 | ''' 20 | 21 | from nilearn import datasets 22 | import argparse 23 | from preprocess_data import Reader 24 | import os 25 | import shutil 26 | import sys 27 | 28 | 29 | def str2bool(v): 30 | if isinstance(v, bool): 31 | return v 32 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 33 | return True 34 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 35 | return False 36 | else: 37 | raise argparse.ArgumentTypeError('Boolean value expected.') 38 | 39 | 40 | def main(args): 41 | print(args) 42 | 43 | root_folder = args.root_path 44 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/') 45 | if not os.path.exists(data_folder): 46 | os.makedirs(data_folder) 47 | 48 | pipeline = args.pipeline 49 | atlas = args.atlas 50 | download = args.download 51 | 52 | # Files to fetch 53 | 54 | files = ['rois_' + atlas] 55 | 56 | filemapping = {'func_preproc': 'func_preproc.nii.gz', 57 | files[0]: files[0] + '.1D'} 58 | 59 | 60 | # Download database files 61 | # if download == True: 62 | # abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline, 63 | # band_pass_filtering=True, global_signal_regression=False, derivatives=files, 64 | # quality_checked=False) 65 | reader = Reader(root_folder, args.id_file_path) 66 | subject_IDs = reader.get_ids() #changed path to data path 67 | subject_IDs = subject_IDs.tolist() 68 | 69 | # Create a folder for each subject 70 | for s, fname in zip(subject_IDs, 
reader.fetch_filenames(subject_IDs, files[0], atlas)): 71 | subject_folder = os.path.join(data_folder, s) 72 | if not os.path.exists(subject_folder): 73 | os.mkdir(subject_folder) 74 | 75 | # Get the base filename for each subject 76 | base = fname.split(files[0])[0] 77 | 78 | # Move each subject file to the subject folder 79 | for fl in files: 80 | if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])): 81 | shutil.move(base + filemapping[fl], subject_folder) 82 | 83 | time_series = reader.get_timeseries(subject_IDs, atlas) 84 | 85 | # Compute and save connectivity matrices 86 | reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation') 87 | reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation') 88 | print("done!") 89 | 90 | 91 | if __name__ == '__main__': 92 | parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices') 93 | parser.add_argument('--pipeline', default='cpac', type=str, 94 | help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak.' 95 | ' default: cpac.') 96 | parser.add_argument('--atlas', default='ho', 97 | help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.') 98 | parser.add_argument('--download', default=True, type=str2bool, 99 | help='Dowload data or just compute functional connectivity. 
default: True') 100 | parser.add_argument('--root_path', default="/data/CodeGoat24/FBNETGEN_ho/", type=str, help='The path of the folder containing the dataset folder.') 101 | parser.add_argument('--id_file_path', default="subject_IDs.txt", type=str, help='The path to subject_IDs.txt.') 102 | args = parser.parse_args() 103 | main(args) -------------------------------------------------------------------------------- /util/abide/02-process_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Modified by Xuan Kan 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
15 | 16 | 17 | import sys 18 | import argparse 19 | import pandas as pd 20 | import numpy as np 21 | from preprocess_data import Reader 22 | import deepdish as dd 23 | import warnings 24 | import os 25 | 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | 30 | # Process boolean command line arguments 31 | def str2bool(v): 32 | if isinstance(v, bool): 33 | return v 34 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 35 | return True 36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 37 | return False 38 | else: 39 | raise argparse.ArgumentTypeError('Boolean value expected.') 40 | 41 | 42 | def main(args): 43 | 44 | print('Arguments: \n', args) 45 | 46 | data_folder = os.path.join(args.root_path, 'ABIDE_pcp/cpac/filt_noglobal/') 47 | 48 | 49 | params = dict() 50 | 51 | params['seed'] = args.seed # seed for random initialisation 52 | 53 | # Algorithm choice 54 | params['atlas'] = args.atlas # Atlas for network construction 55 | atlas = args.atlas # Atlas for network construction (node definition) 56 | 57 | reader = Reader(args.root_path, args.id_file_path) 58 | # Get subject IDs and class labels 59 | subject_IDs = reader.get_ids() 60 | labels = reader.get_subject_score(subject_IDs, score='DX_GROUP') 61 | 62 | # Number of subjects and classes for binary classification 63 | num_classes = args.nclass 64 | num_subjects = len(subject_IDs) 65 | params['n_subjects'] = num_subjects 66 | 67 | # Initialise variables for class labels and acquisition sites 68 | # 1 is autism, 2 is control 69 | y_data = np.zeros([num_subjects, num_classes]) # n x 2 70 | y = np.zeros([num_subjects, 1]) # n x 1 71 | 72 | # Get class labels for all subjects 73 | for i in range(num_subjects): 74 | y_data[i, int(labels[subject_IDs[i]]) - 1] = 1 75 | y[i] = int(labels[subject_IDs[i]]) 76 | 77 | 78 | 79 | # Compute feature vectors (vectorised connectivity networks) 80 | fea_corr = reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas) #(1035, 200, 200) 81 | fea_pcorr = 
reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas) #(1035, 200, 200) 82 | 83 | if not os.path.exists(os.path.join(data_folder,'raw')): 84 | os.makedirs(os.path.join(data_folder,'raw')) 85 | for i, subject in enumerate(subject_IDs): 86 | dd.io.save(os.path.join(data_folder,'raw',subject+'.h5'),{'corr':fea_corr[i],'pcorr':fea_pcorr[i],'label':(y[i]-1)}) 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser(description='Classification of the ABIDE dataset using a Ridge classifier. ' 90 | 'MIDA is used to minimize the distribution mismatch between ABIDE sites') 91 | parser.add_argument('--atlas', default='ho', 92 | help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.') 93 | parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 1234.') 94 | parser.add_argument('--nclass', default=2, type=int, help='Number of classes. default:2') 95 | parser.add_argument('--root_path', default="/data/CodeGoat24/FBNETGEN_ho/", type=str, help='The path of the folder containing the dataset folder.') 96 | parser.add_argument('--id_file_path', default="subject_IDs.txt", type=str, help='The path to subject_IDs.txt.') 97 | 98 | 99 | args = parser.parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /util/abide/03-generate_abide_dataset.py: -------------------------------------------------------------------------------- 1 | import deepdish as dd 2 | import os.path as osp 3 | import os 4 | import numpy as np 5 | import argparse 6 | from pathlib import Path 7 | import pandas as pd 8 | 9 | 10 | def main(args): 11 | root_path = '/data/CodeGoat24/FBNETGEN_ho' 12 | data_dir = os.path.join(root_path, 'ABIDE_pcp/cpac/filt_noglobal/raw') 13 | timeseires = os.path.join(root_path, 'ABIDE_pcp/cpac/filt_noglobal/') 14 | 15 | meta_file = os.path.join(root_path, 
'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv') 16 | 17 | meta_file = pd.read_csv(meta_file, header=0) 18 | 19 | id2site = meta_file[["subject", "SITE_ID"]] 20 | 21 | # pandas to map 22 | id2site = id2site.set_index("subject") 23 | id2site = id2site.to_dict()['SITE_ID'] 24 | 25 | times = [] 26 | 27 | labels = [] 28 | pcorrs = [] 29 | 30 | corrs = [] 31 | 32 | site_list = [] 33 | 34 | for f in os.listdir(data_dir): 35 | if osp.isfile(osp.join(data_dir, f)): 36 | fname = f.split('.')[0] 37 | site = id2site[int(fname)] 38 | 39 | 40 | files = os.listdir(osp.join(timeseires, fname)) 41 | 42 | file = list(filter(lambda x: x.endswith("1D"), files))[0] 43 | 44 | time = np.loadtxt(osp.join(timeseires, fname, file), skiprows=0).T 45 | 46 | if time.shape[1] < 100: 47 | continue 48 | 49 | temp = dd.io.load(osp.join(data_dir, f)) 50 | pcorr = temp['pcorr'][()] 51 | 52 | pcorr[pcorr == float('inf')] = 0 53 | 54 | att = temp['corr'][()] 55 | 56 | att[att == float('inf')] = 0 57 | 58 | label = temp['label'] 59 | 60 | times.append(time[:,:100]) 61 | labels.append(label[0]) 62 | corrs.append(att) 63 | pcorrs.append(pcorr) 64 | site_list.append(site) 65 | 66 | np.save(os.path.join(root_path, 'ABIDE_pcp/abide.npy'), {'timeseires': np.array(times), "label": np.array(labels),"corr": np.array(corrs),"pcorr": np.array(pcorrs), 'site': np.array(site_list)}) 67 | 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser(description='Generate the final dataset') 71 | parser.add_argument('--root_path', default='', type=str, help='The path of the folder containing the dataset folder.') 72 | args = parser.parse_args() 73 | main(args) 74 | -------------------------------------------------------------------------------- /util/abide/__pycache__/preprocess_data.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeGoat24/PLSNet/734442dd57a7374072c8a891df8c0a05776f4eda/util/abide/__pycache__/preprocess_data.cpython-37.pyc -------------------------------------------------------------------------------- /util/abide/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot , Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implcd ied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 
16 | 17 | 18 | import os 19 | import warnings 20 | import glob 21 | import csv 22 | import numpy as np 23 | import scipy.io as sio 24 | from nilearn import connectome 25 | import pandas as pd 26 | from scipy.spatial import distance 27 | from scipy import signal 28 | from sklearn.compose import ColumnTransformer 29 | from sklearn.preprocessing import Normalizer 30 | from sklearn.preprocessing import OrdinalEncoder 31 | from sklearn.preprocessing import OneHotEncoder 32 | from sklearn.preprocessing import StandardScaler 33 | warnings.filterwarnings("ignore") 34 | 35 | # Input data variables 36 | 37 | 38 | class Reader: 39 | 40 | def __init__(self, root_path, id_file_path=None) -> None: 41 | 42 | root_folder = root_path 43 | self.data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal') 44 | self.phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv') 45 | self.id_file = id_file_path 46 | 47 | 48 | def fetch_filenames(self, subject_IDs, file_type, atlas): 49 | """ 50 | subject_list : list of short subject IDs in string format 51 | file_type : must be one of the available file types 52 | filemapping : resulting file name format 53 | returns: 54 | filenames : list of filetypes (same length as subject_list) 55 | """ 56 | 57 | filemapping = {'func_preproc': '_func_preproc.nii.gz', 58 | 'rois_' + atlas: '_rois_' + atlas + '.1D'} 59 | # The list to be filled 60 | filenames = [] 61 | 62 | # Fill list with requested file paths 63 | for i in range(len(subject_IDs)): 64 | os.chdir(self.data_folder) 65 | find_files = glob.glob('*' + subject_IDs[i] + filemapping[file_type]) 66 | if len(find_files) > 0: 67 | filenames.append(find_files[0]) 68 | else: 69 | if os.path.isdir(self.data_folder + '/' + subject_IDs[i]): 70 | os.chdir(self.data_folder + '/' + subject_IDs[i]) 71 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) 72 | else: 73 | filenames.append('N/A') 74 | return filenames 75 | 76 | 77 | # Get 
timeseries arrays for list of subjects 78 | def get_timeseries(self, subject_list, atlas_name, silence=False): 79 | """ 80 | subject_list : list of short subject IDs in string format 81 | atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200 82 | returns: 83 | time_series : list of timeseries arrays, each of shape (timepoints x regions) 84 | """ 85 | 86 | timeseries = [] 87 | for i in range(len(subject_list)): 88 | subject_folder = os.path.join(self.data_folder, subject_list[i]) 89 | ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')] 90 | fl = os.path.join(subject_folder, ro_file[0]) 91 | if silence != True: 92 | print("Reading timeseries file %s" % fl) 93 | timeseries.append(np.loadtxt(fl, skiprows=0)) 94 | 95 | return timeseries 96 | 97 | 98 | # compute connectivity matrices 99 | def subject_connectivity(self, timeseries, subjects, atlas_name, kind, iter_no='', seed=1234, 100 | n_subjects='', save=True, save_path=None): 101 | """ 102 | timeseries : timeseries table for subject (timepoints x regions) 103 | subjects : subject IDs 104 | atlas_name : name of the parcellation atlas used 105 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 106 | iter_no : tangent connectivity iteration number for cross validation evaluation 107 | save : save the connectivity matrix to a file 108 | save_path : specify path to save the matrix if different from subject folder 109 | returns: 110 | connectivity : connectivity matrix (regions x regions) 111 | """ 112 | 113 | if kind in ['TPE', 'TE', 'correlation','partial correlation']: 114 | if kind not in ['TPE', 'TE']: 115 | conn_measure = connectome.ConnectivityMeasure(kind=kind) 116 | connectivity = conn_measure.fit_transform(timeseries) 117 | else: 118 | if kind == 'TPE': 119 | conn_measure = connectome.ConnectivityMeasure(kind='correlation') 120 | conn_mat = conn_measure.fit_transform(timeseries) 121 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 122 | connectivity_fit = conn_measure.fit(conn_mat) 123 | connectivity = connectivity_fit.transform(conn_mat) 124 | else: 125 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 126 | connectivity_fit = conn_measure.fit(timeseries) 127 | connectivity = connectivity_fit.transform(timeseries) 128 | 129 | if save: 130 | if not save_path: 131 | save_path = self.data_folder 132 | if kind not in ['TPE', 'TE']: 133 | for i, subj_id in enumerate(subjects): 134 | subject_file = os.path.join(save_path, subj_id, 135 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat') 136 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 137 | return connectivity 138 | else: 139 | for i, subj_id in enumerate(subjects): 140 | subject_file = os.path.join(save_path, subj_id, 141 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str( 142 | iter_no) + '_' + str(seed) + '_' + validation_ext + str( 143 | n_subjects) + '.mat') 144 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 145 | return connectivity_fit 146 | 147 | 148 | # Get the list of subject IDs 149 | 150 | def get_ids(self, num_subjects=None): 151 | """ 152 | return: 
153 | subject_IDs : list of all subject IDs 154 | """ 155 | 156 | subject_IDs = np.genfromtxt(self.id_file, dtype=str) 157 | 158 | if num_subjects is not None: 159 | subject_IDs = subject_IDs[:num_subjects] 160 | 161 | return subject_IDs 162 | 163 | 164 | # Get phenotype values for a list of subjects 165 | def get_subject_score(self, subject_list, score): 166 | scores_dict = {} 167 | 168 | with open(self.phenotype) as csv_file: 169 | reader = csv.DictReader(csv_file) 170 | for row in reader: 171 | if row['SUB_ID'] in subject_list: 172 | if score == 'HANDEDNESS_CATEGORY': 173 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 174 | scores_dict[row['SUB_ID']] = 'R' 175 | elif row[score] == 'Mixed': 176 | scores_dict[row['SUB_ID']] = 'Ambi' 177 | elif row[score] == 'L->R': 178 | scores_dict[row['SUB_ID']] = 'Ambi' 179 | else: 180 | scores_dict[row['SUB_ID']] = row[score] 181 | elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'): 182 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 183 | scores_dict[row['SUB_ID']] = 100 184 | else: 185 | scores_dict[row['SUB_ID']] = float(row[score]) 186 | 187 | else: 188 | scores_dict[row['SUB_ID']] = row[score] 189 | 190 | return scores_dict 191 | 192 | 193 | # preprocess phenotypes. 
Categorical -> ordinal representation 194 | @staticmethod 195 | def preprocess_phenotypes(pheno_ft, params): 196 | if params['model'] == 'MIDA': 197 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough') 198 | else: 199 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough') 200 | 201 | pheno_ft = ct.fit_transform(pheno_ft) 202 | pheno_ft = pheno_ft.astype('float32') 203 | 204 | return (pheno_ft) 205 | 206 | 207 | # create phenotype feature vector to concatenate with fmri feature vectors 208 | @staticmethod 209 | def phenotype_ft_vector(pheno_ft, num_subjects, params): 210 | gender = pheno_ft[:, 0] 211 | if params['model'] == 'MIDA': 212 | eye = pheno_ft[:, 0] 213 | hand = pheno_ft[:, 2] 214 | age = pheno_ft[:, 3] 215 | fiq = pheno_ft[:, 4] 216 | else: 217 | eye = pheno_ft[:, 2] 218 | hand = pheno_ft[:, 3] 219 | age = pheno_ft[:, 4] 220 | fiq = pheno_ft[:, 5] 221 | 222 | phenotype_ft = np.zeros((num_subjects, 4)) 223 | phenotype_ft_eye = np.zeros((num_subjects, 2)) 224 | phenotype_ft_hand = np.zeros((num_subjects, 3)) 225 | 226 | for i in range(num_subjects): 227 | phenotype_ft[i, int(gender[i])] = 1 228 | phenotype_ft[i, -2] = age[i] 229 | phenotype_ft[i, -1] = fiq[i] 230 | phenotype_ft_eye[i, int(eye[i])] = 1 231 | phenotype_ft_hand[i, int(hand[i])] = 1 232 | 233 | if params['model'] == 'MIDA': 234 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) 235 | else: 236 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1) 237 | 238 | return phenotype_ft 239 | 240 | 241 | # Load precomputed fMRI connectivity networks 242 | def get_networks(self, subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal", 243 | variable='connectivity'): 244 | """ 245 | subject_list : list of subject IDs 246 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 247 | atlas_name : name of the parcellation atlas used 248 | variable : variable name in the .mat file that has been used to save the precomputed networks 249 | return: 250 | matrix : feature matrix of connectivity networks (num_subjects x network_size) 251 | """ 252 | 253 | all_networks = [] 254 | for subject in subject_list: 255 | if len(kind.split()) == 2: 256 | kind = '_'.join(kind.split()) 257 | fl = os.path.join(self.data_folder, subject, 258 | subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat") 259 | 260 | 261 | matrix = sio.loadmat(fl)[variable] 262 | all_networks.append(matrix) 263 | 264 | if kind in ['TE', 'TPE']: 265 | norm_networks = [mat for mat in all_networks] 266 | else: 267 | norm_networks = [np.arctanh(mat) for mat in all_networks] 268 | 269 | networks = np.stack(norm_networks) 270 | 271 | return networks 272 | 273 | -------------------------------------------------------------------------------- /util/abide/readme.md: -------------------------------------------------------------------------------- 1 | # ABIDE 2 | 3 | These scripts will download and preprocess ABIDE dataset. 4 | 5 | ## Usage 6 | 7 | ```bash 8 | cd util/abide/ 9 | 10 | # If you meet time-out error, execute this command repeatly. The script can continue to download from the last failed file. 11 | python 01-fetch_data.py --root_path /path/to/the/save/folder/ --id_file_path subject_IDs.txt --download True 12 | 13 | # Generate correlation matrices. 14 | python 02-process_data.py --root_path /path/to/the/save/folder/ --id_file_path subject_IDs.txt 15 | 16 | # Generate the final dataset. 
17 | python 03-generate_abide_dataset.py --root_path /path/to/the/save/folder/ 18 | ``` -------------------------------------------------------------------------------- /util/abide/subject_IDs.txt: -------------------------------------------------------------------------------- 1 | 50128 2 | 51203 3 | 50325 4 | 50117 5 | 50573 6 | 50741 7 | 50779 8 | 51009 9 | 50746 10 | 50574 11 | 50110 12 | 50322 13 | 51036 14 | 51204 15 | 50119 16 | 50126 17 | 50314 18 | 51490 19 | 50784 20 | 51464 21 | 51000 22 | 51038 23 | 50748 24 | 51235 25 | 51007 26 | 51463 27 | 50783 28 | 50777 29 | 50313 30 | 50121 31 | 51053 32 | 51261 33 | 50723 34 | 50511 35 | 51295 36 | 50347 37 | 50982 38 | 50976 39 | 51098 40 | 51292 41 | 50340 42 | 50516 43 | 50724 44 | 51266 45 | 51054 46 | 50186 47 | 50529 48 | 50985 49 | 50520 50 | 50376 51 | 50978 52 | 50144 53 | 51096 54 | 50382 55 | 51250 56 | 51062 57 | 50349 58 | 51065 59 | 50385 60 | 51257 61 | 50143 62 | 51091 63 | 50371 64 | 50527 65 | 51268 66 | 50188 67 | 50518 68 | 50749 69 | 51039 70 | 50776 71 | 50120 72 | 50312 73 | 51006 74 | 51234 75 | 50782 76 | 51462 77 | 50118 78 | 51465 79 | 50785 80 | 51001 81 | 50315 82 | 50127 83 | 51491 84 | 51008 85 | 50778 86 | 51205 87 | 50575 88 | 50747 89 | 50111 90 | 50129 91 | 50116 92 | 50324 93 | 50740 94 | 50572 95 | 51030 96 | 51202 97 | 50370 98 | 50142 99 | 51090 100 | 50526 101 | 51256 102 | 51064 103 | 50519 104 | 50189 105 | 51269 106 | 51063 107 | 50383 108 | 51251 109 | 50521 110 | 50145 111 | 51097 112 | 50979 113 | 50377 114 | 50348 115 | 51055 116 | 50187 117 | 51267 118 | 51293 119 | 50341 120 | 50725 121 | 51258 122 | 50984 123 | 50528 124 | 50970 125 | 50510 126 | 50722 127 | 51294 128 | 50346 129 | 51260 130 | 51052 131 | 51099 132 | 50977 133 | 50379 134 | 50983 135 | 50039 136 | 50496 137 | 51312 138 | 50234 139 | 50006 140 | 50650 141 | 50802 142 | 50668 143 | 51118 144 | 50657 145 | 50233 146 | 51127 147 | 51315 148 | 50491 149 | 50008 150 | 50498 151 | 50037 152 | 50205 153 
| 50661 154 | 51581 155 | 50453 156 | 50695 157 | 51575 158 | 51111 159 | 51323 160 | 51129 161 | 50659 162 | 51324 163 | 51116 164 | 51572 165 | 50692 166 | 50666 167 | 50202 168 | 50030 169 | 51142 170 | 51370 171 | 50269 172 | 51189 173 | 50251 174 | 50407 175 | 50438 176 | 51348 177 | 50603 178 | 50267 179 | 51187 180 | 50055 181 | 51341 182 | 50293 183 | 51173 184 | 51174 185 | 51346 186 | 50294 187 | 51180 188 | 50052 189 | 50260 190 | 50604 191 | 50436 192 | 50658 193 | 51128 194 | 50667 195 | 50455 196 | 50031 197 | 50203 198 | 51117 199 | 51325 200 | 50693 201 | 51573 202 | 50499 203 | 50009 204 | 51574 205 | 50694 206 | 51322 207 | 51110 208 | 50204 209 | 50036 210 | 51580 211 | 50660 212 | 50803 213 | 50669 214 | 51314 215 | 51126 216 | 50490 217 | 50656 218 | 50232 219 | 50038 220 | 50804 221 | 50007 222 | 50235 223 | 50651 224 | 50463 225 | 50497 226 | 51121 227 | 51313 228 | 50261 229 | 51181 230 | 50053 231 | 50437 232 | 50605 233 | 51347 234 | 50295 235 | 51175 236 | 50408 237 | 51172 238 | 51340 239 | 50292 240 | 50602 241 | 51186 242 | 50054 243 | 50266 244 | 50259 245 | 50250 246 | 50406 247 | 51349 248 | 50439 249 | 50257 250 | 51188 251 | 50268 252 | 51195 253 | 50047 254 | 50275 255 | 50611 256 | 51161 257 | 51353 258 | 50281 259 | 51159 260 | 51354 261 | 50286 262 | 51166 263 | 50424 264 | 50616 265 | 50272 266 | 51192 267 | 50040 268 | 50049 269 | 51362 270 | 51150 271 | 50412 272 | 50620 273 | 50618 274 | 50288 275 | 51168 276 | 50627 277 | 50415 278 | 50243 279 | 51365 280 | 50441 281 | 50217 282 | 50025 283 | 50819 284 | 51331 285 | 51103 286 | 51567 287 | 50687 288 | 50826 289 | 51558 290 | 51560 291 | 51104 292 | 51336 293 | 50022 294 | 50210 295 | 50446 296 | 51309 297 | 50821 298 | 51132 299 | 51300 300 | 51556 301 | 50642 302 | 50470 303 | 50014 304 | 51569 305 | 50689 306 | 50817 307 | 50013 308 | 50477 309 | 50645 310 | 50483 311 | 51307 312 | 51135 313 | 50448 314 | 51338 315 | 51169 316 | 50289 317 | 50619 318 | 51364 319 | 51156 
320 | 50414 321 | 50626 322 | 50242 323 | 50048 324 | 50245 325 | 50621 326 | 50413 327 | 51151 328 | 51363 329 | 50628 330 | 50617 331 | 50425 332 | 51193 333 | 50041 334 | 50273 335 | 51167 336 | 51355 337 | 50287 338 | 51352 339 | 50280 340 | 51160 341 | 50274 342 | 51194 343 | 50046 344 | 50422 345 | 50610 346 | 50482 347 | 51134 348 | 51306 349 | 50012 350 | 50644 351 | 51339 352 | 50449 353 | 50643 354 | 50015 355 | 51301 356 | 51133 357 | 50485 358 | 51557 359 | 50816 360 | 50688 361 | 51568 362 | 50211 363 | 50023 364 | 50447 365 | 51561 366 | 51105 367 | 50820 368 | 51308 369 | 51102 370 | 51330 371 | 50686 372 | 51566 373 | 50440 374 | 50818 375 | 50024 376 | 50216 377 | 51559 378 | 50169 379 | 50955 380 | 50156 381 | 51084 382 | 50364 383 | 50700 384 | 50532 385 | 51070 386 | 50390 387 | 51048 388 | 50952 389 | 50738 390 | 50397 391 | 51077 392 | 50999 393 | 50707 394 | 50363 395 | 51083 396 | 50990 397 | 50158 398 | 50964 399 | 51273 400 | 51041 401 | 50193 402 | 50355 403 | 50167 404 | 50503 405 | 50731 406 | 50709 407 | 50399 408 | 51079 409 | 50997 410 | 50736 411 | 50504 412 | 50160 413 | 51280 414 | 50352 415 | 51046 416 | 50194 417 | 51274 418 | 51482 419 | 50306 420 | 50134 421 | 51220 422 | 51012 423 | 51476 424 | 50796 425 | 50339 426 | 50791 427 | 51471 428 | 51015 429 | 51227 430 | 50133 431 | 50301 432 | 50557 433 | 51485 434 | 51218 435 | 50568 436 | 51023 437 | 51211 438 | 50753 439 | 50561 440 | 50105 441 | 50337 442 | 51478 443 | 50798 444 | 50308 445 | 50330 446 | 50102 447 | 50566 448 | 50754 449 | 51216 450 | 51024 451 | 50559 452 | 51229 453 | 50996 454 | 51078 455 | 50962 456 | 50708 457 | 51275 458 | 51047 459 | 50195 460 | 50505 461 | 50737 462 | 51281 463 | 50353 464 | 50161 465 | 50965 466 | 50159 467 | 50991 468 | 50166 469 | 50354 470 | 50730 471 | 50502 472 | 51040 473 | 50192 474 | 51272 475 | 50739 476 | 51049 477 | 50706 478 | 50150 479 | 51082 480 | 50362 481 | 50998 482 | 51076 483 | 50954 484 | 50168 485 | 50391 486 | 
51071 487 | 50365 488 | 50157 489 | 51085 490 | 50701 491 | 51025 492 | 51217 493 | 50103 494 | 50331 495 | 50755 496 | 50567 497 | 51228 498 | 50558 499 | 50560 500 | 50752 501 | 50336 502 | 50104 503 | 51210 504 | 50799 505 | 51479 506 | 50300 507 | 50132 508 | 50556 509 | 51484 510 | 51470 511 | 50790 512 | 51226 513 | 51014 514 | 50569 515 | 51219 516 | 51013 517 | 51221 518 | 50797 519 | 51477 520 | 50551 521 | 51483 522 | 50135 523 | 50307 524 | 50338 525 | 50171 526 | 50343 527 | 51291 528 | 50727 529 | 50515 530 | 50185 531 | 51057 532 | 51265 533 | 50972 534 | 50388 535 | 50986 536 | 51068 537 | 51262 538 | 50182 539 | 51050 540 | 51606 541 | 50344 542 | 51296 543 | 50981 544 | 50149 545 | 51254 546 | 50386 547 | 50988 548 | 51066 549 | 50372 550 | 50524 551 | 51059 552 | 50711 553 | 50523 554 | 51095 555 | 50147 556 | 50375 557 | 51061 558 | 51253 559 | 50381 560 | 51298 561 | 51238 562 | 50577 563 | 50745 564 | 50321 565 | 50113 566 | 51207 567 | 51035 568 | 51469 569 | 50789 570 | 50319 571 | 51456 572 | 51032 573 | 50114 574 | 50326 575 | 50742 576 | 50570 577 | 51209 578 | 51236 579 | 50780 580 | 51460 581 | 50774 582 | 50122 583 | 50310 584 | 51458 585 | 50317 586 | 50125 587 | 51493 588 | 50773 589 | 51467 590 | 50787 591 | 51231 592 | 51003 593 | 51252 594 | 50380 595 | 51060 596 | 50710 597 | 50374 598 | 51094 599 | 50146 600 | 51299 601 | 51093 602 | 50373 603 | 50525 604 | 51067 605 | 50989 606 | 51255 607 | 50387 608 | 50728 609 | 51058 610 | 50345 611 | 51297 612 | 50183 613 | 51051 614 | 51263 615 | 51607 616 | 50148 617 | 50974 618 | 51264 619 | 50184 620 | 51056 621 | 50342 622 | 50170 623 | 50514 624 | 50726 625 | 51069 626 | 50987 627 | 50973 628 | 51459 629 | 50329 630 | 50786 631 | 51466 632 | 51002 633 | 51230 634 | 50124 635 | 50316 636 | 50772 637 | 51492 638 | 50578 639 | 51208 640 | 50775 641 | 50311 642 | 50123 643 | 51237 644 | 51461 645 | 50781 646 | 50318 647 | 50788 648 | 51468 649 | 50327 650 | 50115 651 | 50571 652 | 50743 
653 | 51457 654 | 51201 655 | 51033 656 | 51239 657 | 51034 658 | 51206 659 | 50744 660 | 50576 661 | 50112 662 | 50320 663 | 50060 664 | 50252 665 | 50404 666 | 51146 667 | 50609 668 | 50299 669 | 51179 670 | 51373 671 | 51141 672 | 50403 673 | 50255 674 | 50058 675 | 50297 676 | 51345 677 | 51177 678 | 50263 679 | 50051 680 | 51183 681 | 50435 682 | 50607 683 | 51148 684 | 50056 685 | 51184 686 | 50264 687 | 51170 688 | 50290 689 | 51342 690 | 50801 691 | 51329 692 | 50466 693 | 50654 694 | 51316 695 | 51124 696 | 50492 697 | 51578 698 | 50698 699 | 50208 700 | 51123 701 | 51311 702 | 50005 703 | 50237 704 | 50653 705 | 51318 706 | 50468 707 | 51327 708 | 50691 709 | 51571 710 | 50665 711 | 51585 712 | 50033 713 | 50201 714 | 50239 715 | 50206 716 | 50034 717 | 51582 718 | 51576 719 | 50696 720 | 51320 721 | 51112 722 | 50291 723 | 51343 724 | 51171 725 | 50433 726 | 50601 727 | 50265 728 | 50057 729 | 51185 730 | 50050 731 | 51182 732 | 50262 733 | 50606 734 | 50434 735 | 50296 736 | 51344 737 | 51149 738 | 50402 739 | 50254 740 | 51140 741 | 50059 742 | 51147 743 | 50253 744 | 50405 745 | 51178 746 | 50298 747 | 50608 748 | 50697 749 | 51577 750 | 51113 751 | 51321 752 | 50035 753 | 50207 754 | 50663 755 | 51583 756 | 50469 757 | 51319 758 | 51584 759 | 50664 760 | 50200 761 | 50032 762 | 51326 763 | 51114 764 | 51570 765 | 50690 766 | 50807 767 | 50209 768 | 50699 769 | 51579 770 | 50236 771 | 50004 772 | 50652 773 | 50494 774 | 51122 775 | 51328 776 | 50800 777 | 51317 778 | 50493 779 | 50655 780 | 50467 781 | 50003 782 | 51563 783 | 50683 784 | 51335 785 | 51107 786 | 50213 787 | 50445 788 | 51138 789 | 50648 790 | 50822 791 | 50442 792 | 50026 793 | 50214 794 | 51100 795 | 51332 796 | 51564 797 | 50019 798 | 50825 799 | 50489 800 | 50010 801 | 50646 802 | 50480 803 | 51136 804 | 51304 805 | 51109 806 | 51303 807 | 51131 808 | 50487 809 | 50017 810 | 50028 811 | 50814 812 | 50418 813 | 51165 814 | 50285 815 | 51357 816 | 50615 817 | 50427 818 | 50043 819 | 
51191 820 | 50271 821 | 50249 822 | 50276 823 | 50044 824 | 51196 825 | 50612 826 | 50282 827 | 51350 828 | 51162 829 | 51359 830 | 50416 831 | 50624 832 | 50240 833 | 51154 834 | 50278 835 | 51198 836 | 51153 837 | 51361 838 | 50247 839 | 50623 840 | 50411 841 | 50016 842 | 51130 843 | 51302 844 | 50486 845 | 50815 846 | 50029 847 | 50481 848 | 51305 849 | 51137 850 | 50011 851 | 50647 852 | 50812 853 | 51333 854 | 51101 855 | 51565 856 | 50685 857 | 50443 858 | 50215 859 | 50027 860 | 50488 861 | 50824 862 | 50020 863 | 50212 864 | 50444 865 | 50682 866 | 51562 867 | 51106 868 | 51334 869 | 50649 870 | 50823 871 | 51139 872 | 51199 873 | 50279 874 | 50246 875 | 50410 876 | 50622 877 | 51360 878 | 51152 879 | 51358 880 | 50428 881 | 51155 882 | 50625 883 | 50417 884 | 50241 885 | 50248 886 | 51163 887 | 50283 888 | 51351 889 | 50045 890 | 51197 891 | 50277 892 | 50613 893 | 50421 894 | 50419 895 | 51369 896 | 50426 897 | 50614 898 | 50270 899 | 50042 900 | 51190 901 | 50284 902 | 51356 903 | 51164 904 | 51472 905 | 50792 906 | 51224 907 | 51016 908 | 50302 909 | 50130 910 | 51486 911 | 50554 912 | 51029 913 | 51481 914 | 50553 915 | 50305 916 | 51011 917 | 51223 918 | 50795 919 | 50333 920 | 50757 921 | 50565 922 | 51027 923 | 51215 924 | 51488 925 | 51018 926 | 51212 927 | 51020 928 | 50562 929 | 50750 930 | 50334 931 | 50106 932 | 51279 933 | 50199 934 | 50509 935 | 51074 936 | 50704 937 | 51080 938 | 50152 939 | 50360 940 | 50956 941 | 50358 942 | 50367 943 | 51087 944 | 50969 945 | 50531 946 | 50703 947 | 51241 948 | 51073 949 | 50960 950 | 50994 951 | 51248 952 | 50507 953 | 50735 954 | 50351 955 | 50163 956 | 51277 957 | 50197 958 | 51045 959 | 50993 960 | 50369 961 | 51089 962 | 50967 963 | 50190 964 | 51042 965 | 50164 966 | 50958 967 | 50356 968 | 50732 969 | 50500 970 | 50751 971 | 50563 972 | 50107 973 | 50335 974 | 51021 975 | 51213 976 | 51214 977 | 51026 978 | 50332 979 | 50564 980 | 50756 981 | 51019 982 | 51489 983 | 51222 984 | 51010 985 | 51474 
986 | 50794 987 | 51480 988 | 50552 989 | 50304 990 | 50136 991 | 50109 992 | 50131 993 | 50303 994 | 51487 995 | 50555 996 | 50793 997 | 51473 998 | 51017 999 | 51225 1000 | 51028 1001 | 50966 1002 | 51088 1003 | 50368 1004 | 50992 1005 | 50357 1006 | 50959 1007 | 50501 1008 | 50733 1009 | 51271 1010 | 50191 1011 | 51249 1012 | 50995 1013 | 50961 1014 | 50196 1015 | 51044 1016 | 51276 1017 | 50162 1018 | 50350 1019 | 51282 1020 | 50359 1021 | 50957 1022 | 51072 1023 | 51240 1024 | 50968 1025 | 51086 1026 | 50366 1027 | 50702 1028 | 50530 1029 | 50198 1030 | 51278 1031 | 50705 1032 | 50361 1033 | 51081 1034 | 50153 1035 | 51075 1036 | -------------------------------------------------------------------------------- /util/analysis/extract_info_from_log.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import numpy as np 4 | def main(args): 5 | table = [] 6 | with open(args.path, 'r') as f: 7 | lines = f.readlines() 8 | 9 | for l in lines: 10 | value = re.findall(r'.*Epoch\[(\d+)/500\].*Train Loss: (\d+\.\d+).*Test Loss: (\d+\.\d+)', l) 11 | table.append(value[0]) 12 | 13 | s = f'|Epoch|' 14 | for i in range(0, 500, 50): 15 | s += f'{i}|' 16 | print(s) 17 | 18 | for j, name in enumerate(['Train Loss', "Test Loss"]): 19 | s = f'|{name}|' 20 | for i in range(0, 500, 50): 21 | s += f'{table[i][j+1]}|' 22 | print(s) 23 | 24 | # data = np.load(args.path, allow_pickle=True).item() 25 | # final_fc = data["timeseires"] 26 | # final_pearson = data["corr"] 27 | # labels = data["label"] 28 | 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--path', default='/home/star/CodeGoat24/FBNETGEN/result/ABIDE_AAL_71.43%/learnable_matrix.npy', type=str, 34 | help='Log file path.') 35 | args = parser.parse_args() 36 | main(args) -------------------------------------------------------------------------------- /util/logger.py: 
# --------------------------------------------------------------------------------
import logging


class Logger:
    """Send messages to the root logger, printed to stderr at INFO level."""

    def __init__(self):
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.INFO)
        # Close and drop handlers already attached to the root logger (e.g. by
        # an earlier Logger instance) so every record is printed only once.
        for handler in self.logger.handlers:
            handler.close()
        self.logger.handlers.clear()

        formatter = logging.Formatter(
            '[%(asctime)s][%(filename)s][L%(lineno)d][%(levelname)s] %(message)s')
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)

    def info(self, info: str):
        """Log ``info`` at INFO level.

        The ``filename``/``lineno`` fields in the output refer to this method,
        not to the caller, because the record is created here.
        """
        self.logger.info(info)


# -------------------------------------------------------------------------------- /util/loss.py:
import torch
EPS = 1e-10


def topk_loss(s, ratio):
    """Push the top ``ratio`` fraction of each row of ``s`` towards 1 and the rest towards 0.

    ``s`` is sorted ascending along dim 1. The largest ``k = int(s.size(1) * ratio)``
    entries contribute ``-log(s)`` and the remaining ones ``-log(1 - s)``;
    each term is averaged over its entries. Values are presumably in [0, 1]
    (e.g. sigmoid outputs) — TODO confirm at call sites.

    A term whose slice is empty (``k == 0`` or ``k >= s.size(1)``) is omitted.
    Previously, slicing with ``-0`` averaged an empty tensor and the loss
    became NaN.
    """
    s = s.sort(dim=1).values
    n = s.size(1)
    k = min(max(int(n * ratio), 0), n)
    split = n - k  # after the ascending sort, columns [split:] are the k largest
    top, rest = s[:, split:], s[:, :split]

    res = s.new_zeros(())
    if top.numel() > 0:
        res = res - torch.log(top + EPS).mean()
    if rest.numel() > 0:
        res = res - torch.log(1 - rest + EPS).mean()
    return res


def inner_loss(label, matrixs):
    """Sum, over classes 0 and 1, of the mean per-element variance of ``matrixs`` within that class.

    A class contributes only when it has at least two samples, because a
    variance needs more than one. Returns the int ``0`` when no class
    qualifies.
    """
    loss = 0

    if torch.sum(label == 0) > 1:
        loss += torch.mean(torch.var(matrixs[label == 0], dim=0))

    if torch.sum(label == 1) > 1:
        loss += torch.mean(torch.var(matrixs[label == 1], dim=0))

    return loss


def intra_loss(label, matrixs):
    """Return ``1 - mean((m0 - m1) ** 2)``, where ``mc`` is the mean of ``matrixs`` over class ``c``.

    Returns the int ``0`` when either class is absent from ``label``.
    """
    a, b = None, None

    if torch.sum(label == 0) > 0:
        a = torch.mean(matrixs[label == 0], dim=0)

    if torch.sum(label == 1) > 0:
        b = torch.mean(matrixs[label == 1], dim=0)
    if a is not None and b is not None:
        return 1 - torch.mean(torch.pow(a - b, 2))
    else:
        return 0


def mixup_cluster_loss(matrixs, y_a, y_b, lam, intra_weight=2):
    """Cluster loss for a mixup batch using soft class weights.

    ``y_1 = lam * y_a + (1 - lam) * y_b`` and ``y_0 = 1 - y_1``.

    For each class with positive total weight, the loss adds the weighted
    mean L1 distance from each flattened matrix to that class's weighted
    center, divided by ``roi_num ** 2``.

    When both classes are present, it also adds
    ``intra_weight * (1 - ||center_0 - center_1||_1 / roi_num ** 2)``.

    ``matrixs`` must be 3-D; it is presumably (batch, roi, roi), since the
    normalisation uses ``roi_num ** 2`` — TODO confirm.
    """
    y_1 = lam * y_a.float() + (1 - lam) * y_b.float()

    y_0 = 1 - y_1

    bz, roi_num, _ = matrixs.shape
    matrixs = matrixs.reshape((bz, -1))
    sum_1 = torch.sum(y_1)
    sum_0 = torch.sum(y_0)
    loss = 0.0

    if sum_0 > 0:
        center_0 = torch.matmul(y_0, matrixs) / sum_0
        diff_0 = torch.norm(matrixs - center_0, p=1, dim=1)
        loss += torch.matmul(y_0, diff_0) / (sum_0 * roi_num * roi_num)
    if sum_1 > 0:
        center_1 = torch.matmul(y_1, matrixs) / sum_1
        diff_1 = torch.norm(matrixs - center_1, p=1, dim=1)
        loss += torch.matmul(y_1, diff_1) / (sum_1 * roi_num * roi_num)
    if sum_0 > 0 and sum_1 > 0:
        loss += intra_weight * \
            (1 - torch.norm(center_0 - center_1, p=1) / (roi_num * roi_num))

    return loss


# -------------------------------------------------------------------------------- /util/meter.py:
from typing import List
import torch


def accuracy(output: torch.Tensor, target: torch.Tensor, top_k=(1,)) -> List[float]:
    """Return the top-k accuracy, in percent, for each k in ``top_k``.

    ``output`` holds class scores along dim 1; ``target`` holds one class
    index per sample.
    """
    max_k = max(top_k)
    batch_size = target.size(0)

    _, predict = output.topk(max_k, 1, True, True)
    predict = predict.t()
    correct = predict.eq(target.view(1, -1).expand_as(predict))

    res = []
    for k in top_k:
        # ``correct`` derives from a transposed tensor and may not be
        # contiguous, so ``view(-1)`` can fail for k > 1; ``reshape`` copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res


class AverageMeter:
    """Sliding-window mean over the last ``length`` values passed to ``update``."""

    def __init__(self, length: int, name: str = None):
        assert length > 0
        self.name = name
        self.count = 0  # number of filled slots in the window
        self.sum = 0.0  # sum of the values currently in the window
        self.current: int = -1  # index of the most recently written slot
        self.history: List[float] = [None] * length

    @property
    def val(self) -> float:
        """Most recent value (``None`` before the first update)."""
        return self.history[self.current]

    @property
    def avg(self) -> float:
        """Mean of the window; raises ZeroDivisionError before the first update."""
        return self.sum / self.count

    def update(self, val: float):
        """Add ``val``, evicting the oldest value once the window is full."""
        self.current = (self.current + 1) % len(self.history)
        self.sum += val

        old = self.history[self.current]
        if old is None:
            self.count += 1
        else:
            self.sum -= old
        self.history[self.current] = val


class TotalMeter:
    """Running (optionally weighted) mean over every value seen since the last reset."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val: float):
        """Add a single observation."""
        self.sum += val
        self.count += 1

    def update_with_weight(self, val: float, count: int):
        """Add ``val`` as the mean of ``count`` observations."""
        self.sum += val * count
        self.count += count

    def reset(self):
        """Forget all observations."""
        self.sum = 0
        self.count = 0

    @property
    def avg(self):
        """Mean of all observations, or -1 when there are none."""
        if self.count == 0:
            return -1
        return self.sum / self.count


# -------------------------------------------------------------------------------- /util/prepossess.py:
import torch
import numpy as np
import random


def mixup_data(x, nodes, y, alpha=1.0, device='cuda'):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly permuted partner from the same
    batch, using ``lam ~ Beta(alpha, alpha)`` (``lam = 1`` when
    ``alpha <= 0``). Returns ``(mixed_x, mixed_nodes, y_a, y_b, lam)``, where
    ``y_b`` holds the partners' targets. The permutation is created on
    ``device``.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_nodes = lam * nodes + (1 - lam) * nodes[index, :]
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, mixed_nodes, y_a, y_b, lam


def mixup_data_by_class(x, nodes, y, alpha=1.0, device='cuda'):
    '''Apply ``mixup_data`` separately within each class of ``y``.

    Partners therefore always share a label. Returns the mixed ``x``, the
    mixed ``nodes`` and the matching labels, concatenated class by class in
    ``y.unique()`` order.
    '''
    mix_xs, mix_nodes, mix_ys = [], [], []

    for t_y in y.unique():
        idx = y == t_y

        t_mixed_x, t_mixed_nodes, _, _, _ = mixup_data(
            x[idx], nodes[idx], y[idx], alpha=alpha, device=device)
        mix_xs.append(t_mixed_x)
        mix_nodes.append(t_mixed_nodes)

        mix_ys.append(y[idx])

    return torch.cat(mix_xs, dim=0), torch.cat(mix_nodes, dim=0), torch.cat(mix_ys, dim=0)


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def cal_step_connect(connectity, step):
    """Binary multi-step reachability.

    Multiplies ``connectity`` by itself ``step`` times with ``np.dot``, so the
    result is ``connectity ** (step + 1)``, clamping positive entries to 1
    after each product. With ``step == 0`` the input is returned unchanged.
    """
    multi_step = connectity
    for _ in range(step):
        multi_step = np.dot(multi_step, connectity)
        multi_step[multi_step > 0] = 1
    return multi_step


def obtain_partition(dataloader, fc_threshold, step=2, device='cuda'):
    """Build a node-by-edge incidence matrix from the group-mean connectivity.

    1. Average the second element of every batch (the Pearson matrix) over
       all samples.
    2. Binarise the average at ``fc_threshold``.
    3. Expand reachability with ``cal_step_connect``.
    4. For every connected pair ``(i, j)`` with ``j < i``, set column
       ``i*(i-1)//2 + j`` to 1 in rows ``i`` and ``j``.

    Batches may be tuples of any length, because only ``batch[1]`` is read.
    The loop used to unpack exactly three items and failed on loaders that
    also yield a pseudo matrix.

    Returns ``(partition, connect_num)``. ``partition`` is a float tensor of
    shape ``(n, n*(n-1)//2)`` on ``device``. ``connect_num`` is the number of
    nonzero entries divided by ``n``.
    """
    pearsons = [batch[1] for batch in dataloader]

    fc_data = torch.mean(torch.cat(pearsons), dim=0)

    fc_data[fc_data > fc_threshold] = 1
    fc_data[fc_data <= fc_threshold] = 0

    _, n = fc_data.shape

    final_partition = torch.zeros((n, (n - 1) * n // 2))

    connection = cal_step_connect(fc_data, step)
    for i in range(connection.shape[0]):
        for j in range(i):
            if connection[i, j] > 0:
                pair = i * (i - 1) // 2 + j  # row-major index of (i, j) in the strict lower triangle
                final_partition[i, pair] = 1
                final_partition[j, pair] = 1

    connect_num = torch.sum(final_partition > 0) / n
    print(f'Final Partition {connect_num}')

    return final_partition.to(device).float(), connect_num
# --------------------------------------------------------------------------------