├── imports ├── __inits__.py ├── __pycache__ │ ├── preprocess_data.cpython-36.pyc │ └── preprocess_data.cpython-38.pyc ├── utils.py ├── ABIDEDataset.py ├── read_abide_stats_parall.py ├── preprocess_data.py └── gdc.py ├── .idea ├── encodings.xml ├── .gitignore ├── modules.xml ├── misc.xml ├── GNN_biomarker_MEDIA.iml ├── webServers.xml ├── deployment.xml └── inspectionProfiles │ └── Project_Default.xml ├── net ├── inits.py ├── braingraphconv.py ├── braingnn.py └── brainmsgpassing.py ├── README.md ├── 02-process_data.py ├── 01-fetch_data.py ├── 03-main.py ├── data └── subject_ID.txt └── requirements.txt /imports/__inits__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imports/__pycache__/preprocess_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlya/BrainGNN_Pytorch/HEAD/imports/__pycache__/preprocess_data.cpython-36.pyc -------------------------------------------------------------------------------- /imports/__pycache__/preprocess_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlya/BrainGNN_Pytorch/HEAD/imports/__pycache__/preprocess_data.cpython-38.pyc -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/GNN_biomarker_MEDIA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | -------------------------------------------------------------------------------- /net/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | if tensor is not None: 12 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | if tensor is not None: 18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | 
tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /imports/utils.py: -------------------------------------------------------------------------------- 1 | from scipy import stats 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | from scipy.io import loadmat 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn.model_selection import KFold 8 | 9 | 10 | def train_val_test_split(kfold = 5, fold = 0): 11 | n_sub = 1035 12 | id = list(range(n_sub)) 13 | 14 | 15 | import random 16 | random.seed(123) 17 | random.shuffle(id) 18 | 19 | kf = KFold(n_splits=kfold, random_state=123,shuffle = True) 20 | kf2 = KFold(n_splits=kfold-1, shuffle=True, random_state = 666) 21 | 22 | 23 | test_index = list() 24 | train_index = list() 25 | val_index = list() 26 | 27 | for tr,te in kf.split(np.array(id)): 28 | test_index.append(te) 29 | tr_id, val_id = list(kf2.split(tr))[0] 30 | train_index.append(tr[tr_id]) 31 | val_index.append(tr[val_id]) 32 | 33 | train_id = train_index[fold] 34 | test_id = test_index[fold] 35 | val_id = val_index[fold] 36 | 37 | return train_id,val_id,test_id -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Neural Network for Brain Network Analysis 2 | A preliminary implementation of BrainGNN. The example presented here uses the public resting-state fMRI ABIDE dataset for convenience during development. This dataset differs from the ones used in our publication, which are cleaner task-fMRI data. We are still seeking ways to improve representation learning on this noisier data. 3 | 4 | 5 | ## Usage 6 | ### Setup 7 | **pip** 8 | 9 | See `requirements.txt` for the environment configuration. 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | **PyG** 14 | 15 | To install the PyG library, [please refer to the documentation](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html). 16 | 17 | ### Dataset 18 | **ABIDE** 19 | 20 | We treat each fMRI scan as a brain graph. To download the data and construct the graphs, run: 21 | ``` 22 | python 01-fetch_data.py 23 | python 02-process_data.py 24 | ``` 25 | 26 | ### How to run classification? 27 | Training and testing are integrated in `03-main.py`. To run: 28 | ``` 29 | python 03-main.py 30 | ``` 31 | 32 | 33 | ## Citation 34 | If you find the code and dataset useful, please cite our paper.
35 | ```latex 36 | @article{li2020braingnn, 37 | title={Braingnn: Interpretable brain graph neural network for fmri analysis}, 38 | author={Li, Xiaoxiao and Zhou,Yuan and Dvornek, Nicha and Zhang, Muhan and Gao, Siyuan and Zhuang, Juntang and Scheinost, Dustin and Staib, Lawrence and Ventola, Pamela and Duncan, James}, 39 | journal={bioRxiv}, 40 | year={2020}, 41 | publisher={Cold Spring Harbor Laboratory} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /imports/ABIDEDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset,Data 3 | from os.path import join, isfile 4 | from os import listdir 5 | import numpy as np 6 | import os.path as osp 7 | from imports.read_abide_stats_parall import read_data 8 | 9 | 10 | class ABIDEDataset(InMemoryDataset): 11 | def __init__(self, root, name, transform=None, pre_transform=None): 12 | self.root = root 13 | self.name = name 14 | super(ABIDEDataset, self).__init__(root,transform, pre_transform) 15 | self.data, self.slices = torch.load(self.processed_paths[0]) 16 | 17 | @property 18 | def raw_file_names(self): 19 | data_dir = osp.join(self.root,'raw') 20 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] 21 | onlyfiles.sort() 22 | return onlyfiles 23 | @property 24 | def processed_file_names(self): 25 | return 'data.pt' 26 | 27 | def download(self): 28 | # Download to `self.raw_dir`. 29 | return 30 | 31 | def process(self): 32 | # Read data into huge `Data` list. 33 | self.data, self.slices = read_data(self.raw_dir) 34 | 35 | if self.pre_filter is not None: 36 | data_list = [self.get(idx) for idx in range(len(self))] 37 | data_list = [data for data in data_list if self.pre_filter(data)] 38 | self.data, self.slices = self.collate(data_list) 39 | 40 | if self.pre_transform is not None: 41 | data_list = [self.get(idx) for idx in range(len(self))] 42 | data_list = [self.pre_transform(data) for data in data_list] 43 | self.data, self.slices = self.collate(data_list) 44 | 45 | torch.save((self.data, self.slices), self.processed_paths[0]) 46 | 47 | def __repr__(self): 48 | return '{}({})'.format(self.name, len(self)) 49 | -------------------------------------------------------------------------------- /net/braingraphconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Parameter 4 | from net.brainmsgpassing import MyMessagePassing 5 | from torch_geometric.utils import add_remaining_self_loops,softmax 6 | 7 | from torch_geometric.typing import (OptTensor) 8 | 9 | from net.inits import uniform 10 | 11 | 12 | class MyNNConv(MyMessagePassing): 13 | def __init__(self, in_channels, out_channels, nn, normalize=False, bias=True, 14 | **kwargs): 15 | super(MyNNConv, self).__init__(aggr='mean', **kwargs) 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.normalize = normalize 20 | self.nn = nn 21 | #self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 22 | 23 | if bias: 24 | self.bias = Parameter(torch.Tensor(out_channels)) 25 | else: 26 | self.register_parameter('bias', None) 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | # uniform(self.in_channels, self.weight) 32 | uniform(self.in_channels, self.bias) 33 | 34 | def forward(self, x, edge_index, edge_weight=None, pseudo= None, size=None): 
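        # x: node features [N, in_channels]; edge_weight: functional-connectivity edge strengths.
        # pseudo: R-dimensional one-hot ROI encoding (built in read_abide_stats_parall.py); self.nn
        # maps it to a per-node (in_channels x out_channels) weight matrix, i.e. the ROI-aware conv.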
35 | """""" 36 | edge_weight = edge_weight.squeeze() 37 | if size is None and torch.is_tensor(x): 38 | edge_index, edge_weight = add_remaining_self_loops( 39 | edge_index, edge_weight, 1, x.size(0)) 40 | 41 | weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels) 42 | if torch.is_tensor(x): 43 | x = torch.matmul(x.unsqueeze(1), weight).squeeze(1) 44 | else: 45 | x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1), 46 | None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1)) 47 | 48 | # weight = self.nn(pseudo).view(-1, self.out_channels,self.in_channels) 49 | # if torch.is_tensor(x): 50 | # x = torch.matmul(x.unsqueeze(1), weight.permute(0,2,1)).squeeze(1) 51 | # else: 52 | # x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1), 53 | # None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1)) 54 | 55 | return self.propagate(edge_index, size=size, x=x, 56 | edge_weight=edge_weight) 57 | 58 | def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTensor): 59 | edge_weight = softmax(edge_weight, edge_index_i, ptr, size_i) 60 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 61 | 62 | def update(self, aggr_out): 63 | if self.bias is not None: 64 | aggr_out = aggr_out + self.bias 65 | if self.normalize: 66 | aggr_out = F.normalize(aggr_out, p=2, dim=-1) 67 | return aggr_out 68 | 69 | def __repr__(self): 70 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 71 | self.out_channels) 72 | 73 | -------------------------------------------------------------------------------- /02-process_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | 16 | 17 | import sys 18 | import argparse 19 | import pandas as pd 20 | import numpy as np 21 | from imports import preprocess_data as Reader 22 | import deepdish as dd 23 | import warnings 24 | import os 25 | 26 | warnings.filterwarnings("ignore") 27 | root_folder = '/data/' 28 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/') 29 | 30 | # Process boolean command line arguments 31 | def str2bool(v): 32 | if isinstance(v, bool): 33 | return v 34 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 35 | return True 36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 37 | return False 38 | else: 39 | raise argparse.ArgumentTypeError('Boolean value expected.') 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser(description='Classification of the ABIDE dataset using a Ridge classifier. 
' 44 | 'MIDA is used to minimize the distribution mismatch between ABIDE sites') 45 | parser.add_argument('--atlas', default='cc200', 46 | help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.') 47 | parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 1234.') 48 | parser.add_argument('--nclass', default=2, type=int, help='Number of classes. default:2') 49 | 50 | 51 | args = parser.parse_args() 52 | print('Arguments: \n', args) 53 | 54 | 55 | params = dict() 56 | 57 | params['seed'] = args.seed # seed for random initialisation 58 | 59 | # Algorithm choice 60 | params['atlas'] = args.atlas # Atlas for network construction 61 | atlas = args.atlas # Atlas for network construction (node definition) 62 | 63 | # Get subject IDs and class labels 64 | subject_IDs = Reader.get_ids() 65 | labels = Reader.get_subject_score(subject_IDs, score='DX_GROUP') 66 | 67 | # Number of subjects and classes for binary classification 68 | num_classes = args.nclass 69 | num_subjects = len(subject_IDs) 70 | params['n_subjects'] = num_subjects 71 | 72 | # Initialise variables for class labels and acquisition sites 73 | # 1 is autism, 2 is control 74 | y_data = np.zeros([num_subjects, num_classes]) # n x 2 75 | y = np.zeros([num_subjects, 1]) # n x 1 76 | 77 | # Get class labels for all subjects 78 | for i in range(num_subjects): 79 | y_data[i, int(labels[subject_IDs[i]]) - 1] = 1 80 | y[i] = int(labels[subject_IDs[i]]) 81 | 82 | # Compute feature vectors (vectorised connectivity networks) 83 | fea_corr = Reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas) #(1035, 200, 200) 84 | fea_pcorr = Reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas) #(1035, 200, 200) 85 | 86 | if not os.path.exists(os.path.join(data_folder,'raw')): 87 | os.makedirs(os.path.join(data_folder,'raw')) 88 | for i, subject in enumerate(subject_IDs): 89 | dd.io.save(os.path.join(data_folder,'raw',subject+'.h5'),{'corr':fea_corr[i],'pcorr':fea_pcorr[i],'label':y[i]%2}) 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /net/braingnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch_geometric.nn import TopKPooling 5 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp 6 | from torch_geometric.utils import (add_self_loops, sort_edge_index, 7 | remove_self_loops) 8 | from torch_sparse import spspmm 9 | 10 | from net.braingraphconv import MyNNConv 11 | 12 | 13 | ########################################################################################################################## 14 | class Network(torch.nn.Module): 15 | def __init__(self, indim, ratio, nclass, k=8, R=200): 16 | ''' 17 | 18 | :param indim: (int) node feature dimension 19 | :param ratio: (float) pooling ratio in (0,1) 20 | :param nclass: (int) number of classes 21 | :param k: (int) number of communities 22 | :param R: (int) number of ROIs 23 | ''' 24 | super(Network, self).__init__() 25 | 26 | self.indim = indim 27 | self.dim1 = 32 28 | self.dim2 = 32 29 | self.dim3 = 512 30 | self.dim4 = 256 31 | self.dim5 = 8 32 | self.k = k 33 | self.R = R 34 | 35 | self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim1 * self.indim)) 36 | self.conv1 = 
MyNNConv(self.indim, self.dim1, self.n1, normalize=False) 37 | self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid) 38 | self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim2 * self.dim1)) 39 | self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False) 40 | self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid) 41 | 42 | #self.fc1 = torch.nn.Linear((self.dim2) * 2, self.dim2) 43 | self.fc1 = torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2) 44 | self.bn1 = torch.nn.BatchNorm1d(self.dim2) 45 | self.fc2 = torch.nn.Linear(self.dim2, self.dim3) 46 | self.bn2 = torch.nn.BatchNorm1d(self.dim3) 47 | self.fc3 = torch.nn.Linear(self.dim3, nclass) 48 | 49 | 50 | 51 | 52 | def forward(self, x, edge_index, batch, edge_attr, pos): 53 | 54 | x = self.conv1(x, edge_index, edge_attr, pos) 55 | x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch) 56 | 57 | pos = pos[perm] 58 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 59 | 60 | edge_attr = edge_attr.squeeze() 61 | edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0)) 62 | 63 | x = self.conv2(x, edge_index, edge_attr, pos) 64 | x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index,edge_attr, batch) 65 | 66 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 67 | 68 | x = torch.cat([x1,x2], dim=1) 69 | x = self.bn1(F.relu(self.fc1(x))) 70 | x = F.dropout(x, p=0.5, training=self.training) 71 | x = self.bn2(F.relu(self.fc2(x))) 72 | x= F.dropout(x, p=0.5, training=self.training) 73 | x = F.log_softmax(self.fc3(x), dim=-1) 74 | 75 | return x,self.pool1.weight,self.pool2.weight, torch.sigmoid(score1).view(x.size(0),-1), torch.sigmoid(score2).view(x.size(0),-1) 76 | 77 | def augment_adj(self, edge_index, edge_weight, num_nodes): 78 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 79 | num_nodes=num_nodes) 80 | edge_index, edge_weight = sort_edge_index(edge_index, edge_weight, 81 | num_nodes) 82 | edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index, 83 | edge_weight, num_nodes, num_nodes, 84 | num_nodes) 85 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 86 | return edge_index, edge_weight 87 | 88 | -------------------------------------------------------------------------------- /01-fetch_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot , , Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 
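# This script downloads the preprocessed ABIDE I release (C-PAC pipeline, band-pass filtered,
# no global signal regression) via nilearn, sorts the ROI time-series files into one folder per
# subject, and saves correlation / partial-correlation connectivity matrices for each subject.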
16 | 17 | ''' 18 | This script mainly refers to https://github.com/kundaMwiza/fMRI-site-adaptation/blob/master/fetch_data.py 19 | ''' 20 | 21 | from nilearn import datasets 22 | import argparse 23 | from imports import preprocess_data as Reader 24 | import os 25 | import shutil 26 | import sys 27 | 28 | # Input data variables 29 | code_folder = os.getcwd() 30 | root_folder = '/data/' 31 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/') 32 | if not os.path.exists(data_folder): 33 | os.makedirs(data_folder) 34 | shutil.copyfile(os.path.join(root_folder,'subject_ID.txt'), os.path.join(data_folder, 'subject_IDs.txt')) 35 | 36 | def str2bool(v): 37 | if isinstance(v, bool): 38 | return v 39 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 40 | return True 41 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 42 | return False 43 | else: 44 | raise argparse.ArgumentTypeError('Boolean value expected.') 45 | 46 | 47 | def main(): 48 | parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices') 49 | parser.add_argument('--pipeline', default='cpac', type=str, 50 | help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak.' 51 | ' default: cpac.') 52 | parser.add_argument('--atlas', default='cc200', 53 | help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.') 54 | parser.add_argument('--download', default=True, type=str2bool, 55 | help='Dowload data or just compute functional connectivity. default: True') 56 | args = parser.parse_args() 57 | print(args) 58 | 59 | params = dict() 60 | 61 | pipeline = args.pipeline 62 | atlas = args.atlas 63 | download = args.download 64 | 65 | # Files to fetch 66 | 67 | files = ['rois_' + atlas] 68 | 69 | filemapping = {'func_preproc': 'func_preproc.nii.gz', 70 | files[0]: files[0] + '.1D'} 71 | 72 | 73 | # Download database files 74 | if download == True: 75 | abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline, 76 | band_pass_filtering=True, global_signal_regression=False, derivatives=files, 77 | quality_checked=False) 78 | 79 | subject_IDs = Reader.get_ids() #changed path to data path 80 | subject_IDs = subject_IDs.tolist() 81 | 82 | # Create a folder for each subject 83 | for s, fname in zip(subject_IDs, Reader.fetch_filenames(subject_IDs, files[0], atlas)): 84 | subject_folder = os.path.join(data_folder, s) 85 | if not os.path.exists(subject_folder): 86 | os.mkdir(subject_folder) 87 | 88 | # Get the base filename for each subject 89 | base = fname.split(files[0])[0] 90 | 91 | # Move each subject file to the subject folder 92 | for fl in files: 93 | if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])): 94 | shutil.move(base + filemapping[fl], subject_folder) 95 | 96 | time_series = Reader.get_timeseries(subject_IDs, atlas) 97 | 98 | # Compute and save connectivity matrices 99 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation') 100 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation') 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /imports/read_abide_stats_parall.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Xiaoxiao Li 3 | Date: 2019/02/24 4 | ''' 5 | 6 | import os.path as osp 7 | from os import listdir 8 | import os 9 | import glob 10 | import h5py 11 | 12 | import torch 13 | 
import numpy as np 14 | from scipy.io import loadmat 15 | from torch_geometric.data import Data 16 | import networkx as nx 17 | from networkx.convert_matrix import from_numpy_matrix 18 | import multiprocessing 19 | from torch_sparse import coalesce 20 | from torch_geometric.utils import remove_self_loops 21 | from functools import partial 22 | import deepdish as dd 23 | from imports.gdc import GDC 24 | 25 | 26 | def split(data, batch): 27 | node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0) 28 | node_slice = torch.cat([torch.tensor([0]), node_slice]) 29 | 30 | row, _ = data.edge_index 31 | edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0) 32 | edge_slice = torch.cat([torch.tensor([0]), edge_slice]) 33 | 34 | # Edge indices should start at zero for every graph. 35 | data.edge_index -= node_slice[batch[row]].unsqueeze(0) 36 | 37 | slices = {'edge_index': edge_slice} 38 | if data.x is not None: 39 | slices['x'] = node_slice 40 | if data.edge_attr is not None: 41 | slices['edge_attr'] = edge_slice 42 | if data.y is not None: 43 | if data.y.size(0) == batch.size(0): 44 | slices['y'] = node_slice 45 | else: 46 | slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long) 47 | if data.pos is not None: 48 | slices['pos'] = node_slice 49 | 50 | return data, slices 51 | 52 | 53 | def cat(seq): 54 | seq = [item for item in seq if item is not None] 55 | seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq] 56 | return torch.cat(seq, dim=-1).squeeze() if len(seq) > 0 else None 57 | 58 | class NoDaemonProcess(multiprocessing.Process): 59 | @property 60 | def daemon(self): 61 | return False 62 | 63 | @daemon.setter 64 | def daemon(self, value): 65 | pass 66 | 67 | 68 | class NoDaemonContext(type(multiprocessing.get_context())): 69 | Process = NoDaemonProcess 70 | 71 | 72 | def read_data(data_dir): 73 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] 74 | onlyfiles.sort() 75 | batch = [] 76 | pseudo = [] 77 | y_list = [] 78 | edge_att_list, edge_index_list,att_list = [], [], [] 79 | 80 | # parallar computing 81 | cores = multiprocessing.cpu_count() 82 | pool = multiprocessing.Pool(processes=cores) 83 | #pool = MyPool(processes = cores) 84 | func = partial(read_sigle_data, data_dir) 85 | 86 | import timeit 87 | 88 | start = timeit.default_timer() 89 | 90 | res = pool.map(func, onlyfiles) 91 | 92 | pool.close() 93 | pool.join() 94 | 95 | stop = timeit.default_timer() 96 | 97 | print('Time: ', stop - start) 98 | 99 | 100 | 101 | for j in range(len(res)): 102 | edge_att_list.append(res[j][0]) 103 | edge_index_list.append(res[j][1]+j*res[j][4]) 104 | att_list.append(res[j][2]) 105 | y_list.append(res[j][3]) 106 | batch.append([j]*res[j][4]) 107 | pseudo.append(np.diag(np.ones(res[j][4]))) 108 | 109 | edge_att_arr = np.concatenate(edge_att_list) 110 | edge_index_arr = np.concatenate(edge_index_list, axis=1) 111 | att_arr = np.concatenate(att_list, axis=0) 112 | pseudo_arr = np.concatenate(pseudo, axis=0) 113 | y_arr = np.stack(y_list) 114 | edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float() 115 | att_torch = torch.from_numpy(att_arr).float() 116 | y_torch = torch.from_numpy(y_arr).long() # classification 117 | batch_torch = torch.from_numpy(np.hstack(batch)).long() 118 | edge_index_torch = torch.from_numpy(edge_index_arr).long() 119 | pseudo_torch = torch.from_numpy(pseudo_arr).float() 120 | data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos = 
pseudo_torch ) 121 | 122 | 123 | data, slices = split(data, batch_torch) 124 | 125 | return data, slices 126 | 127 | 128 | def read_sigle_data(data_dir,filename,use_gdc =False): 129 | 130 | temp = dd.io.load(osp.join(data_dir, filename)) 131 | 132 | # read edge and edge attribute 133 | pcorr = np.abs(temp['pcorr'][()]) 134 | 135 | num_nodes = pcorr.shape[0] 136 | G = from_numpy_matrix(pcorr) 137 | A = nx.to_scipy_sparse_matrix(G) 138 | adj = A.tocoo() 139 | edge_att = np.zeros(len(adj.row)) 140 | for i in range(len(adj.row)): 141 | edge_att[i] = pcorr[adj.row[i], adj.col[i]] 142 | 143 | edge_index = np.stack([adj.row, adj.col]) 144 | edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att)) 145 | edge_index = edge_index.long() 146 | edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, 147 | num_nodes) 148 | att = temp['corr'][()] 149 | label = temp['label'][()] 150 | 151 | att_torch = torch.from_numpy(att).float() 152 | y_torch = torch.from_numpy(np.array(label)).long() # classification 153 | 154 | data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att) 155 | 156 | if use_gdc: 157 | ''' 158 | Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html 159 | ''' 160 | data.edge_attr = data.edge_attr.squeeze() 161 | gdc = GDC(self_loop_weight=1, normalization_in='sym', 162 | normalization_out='col', 163 | diffusion_kwargs=dict(method='ppr', alpha=0.2), 164 | sparsification_kwargs=dict(method='topk', k=20, 165 | dim=0), exact=True) 166 | data = gdc(data) 167 | return data.edge_attr.data.numpy(),data.edge_index.data.numpy(),data.x.data.numpy(),data.y.data.item(),num_nodes 168 | 169 | else: 170 | return edge_att.data.numpy(),edge_index.data.numpy(),att,label,num_nodes 171 | 172 | if __name__ == "__main__": 173 | data_dir = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw' 174 | filename = '50346.h5' 175 | read_sigle_data(data_dir, filename) 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /net/brainmsgpassing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | 4 | import torch 5 | # from torch_geometric.utils import scatter_ 6 | from torch_scatter import scatter,scatter_add 7 | 8 | special_args = [ 9 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j' 10 | ] 11 | __size_error_msg__ = ('All tensors which should get mapped to the same source ' 12 | 'or target nodes must be of same size in dimension 0.') 13 | 14 | is_python2 = sys.version_info[0] < 3 15 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec 16 | 17 | 18 | class MyMessagePassing(torch.nn.Module): 19 | r"""Base class for creating message passing layers 20 | .. math:: 21 | \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, 22 | \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} 23 | \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right), 24 | where :math:`\square` denotes a differentiable, permutation invariant 25 | function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}` 26 | and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as 27 | MLPs. 28 | See `here `__ for the accompanying tutorial. 30 | Args: 31 | aggr (string, optional): The aggregation scheme to use 32 | (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`). 
33 | (default: :obj:`"add"`) 34 | flow (string, optional): The flow direction of message passing 35 | (:obj:`"source_to_target"` or :obj:`"target_to_source"`). 36 | (default: :obj:`"source_to_target"`) 37 | node_dim (int, optional): The axis along which to propagate. 38 | (default: :obj:`0`) 39 | """ 40 | def __init__(self, aggr='add', flow='source_to_target', node_dim=0): 41 | super(MyMessagePassing, self).__init__() 42 | 43 | self.aggr = aggr 44 | assert self.aggr in ['add', 'mean', 'max'] 45 | 46 | self.flow = flow 47 | assert self.flow in ['source_to_target', 'target_to_source'] 48 | 49 | self.node_dim = node_dim 50 | assert self.node_dim >= 0 51 | 52 | self.__message_args__ = getargspec(self.message)[0][1:] 53 | self.__special_args__ = [(i, arg) 54 | for i, arg in enumerate(self.__message_args__) 55 | if arg in special_args] 56 | self.__message_args__ = [ 57 | arg for arg in self.__message_args__ if arg not in special_args 58 | ] 59 | self.__update_args__ = getargspec(self.update)[0][2:] 60 | 61 | def propagate(self, edge_index, size=None, **kwargs): 62 | r"""The initial call to start propagating messages. 63 | Args: 64 | edge_index (Tensor): The indices of a general (sparse) assignment 65 | matrix with shape :obj:`[N, M]` (can be directed or 66 | undirected). 67 | size (list or tuple, optional): The size :obj:`[N, M]` of the 68 | assignment matrix. If set to :obj:`None`, the size is tried to 69 | get automatically inferred and assumed to be symmetric. 70 | (default: :obj:`None`) 71 | **kwargs: Any additional data which is needed to construct messages 72 | and to update node embeddings. 73 | """ 74 | 75 | dim = self.node_dim 76 | size = [None, None] if size is None else list(size) 77 | assert len(size) == 2 78 | 79 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) 80 | ij = {"_i": i, "_j": j} 81 | 82 | message_args = [] 83 | for arg in self.__message_args__: 84 | if arg[-2:] in ij.keys(): 85 | tmp = kwargs.get(arg[:-2], None) 86 | if tmp is None: # pragma: no cover 87 | message_args.append(tmp) 88 | else: 89 | idx = ij[arg[-2:]] 90 | if isinstance(tmp, tuple) or isinstance(tmp, list): 91 | assert len(tmp) == 2 92 | if tmp[1 - idx] is not None: 93 | if size[1 - idx] is None: 94 | size[1 - idx] = tmp[1 - idx].size(dim) 95 | if size[1 - idx] != tmp[1 - idx].size(dim): 96 | raise ValueError(__size_error_msg__) 97 | tmp = tmp[idx] 98 | 99 | if tmp is None: 100 | message_args.append(tmp) 101 | else: 102 | if size[idx] is None: 103 | size[idx] = tmp.size(dim) 104 | if size[idx] != tmp.size(dim): 105 | raise ValueError(__size_error_msg__) 106 | 107 | tmp = torch.index_select(tmp, dim, edge_index[idx]) 108 | message_args.append(tmp) 109 | else: 110 | message_args.append(kwargs.get(arg, None)) 111 | 112 | size[0] = size[1] if size[0] is None else size[0] 113 | size[1] = size[0] if size[1] is None else size[1] 114 | 115 | kwargs['edge_index'] = edge_index 116 | kwargs['size'] = size 117 | 118 | for (idx, arg) in self.__special_args__: 119 | if arg[-2:] in ij.keys(): 120 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]]) 121 | else: 122 | message_args.insert(idx, kwargs[arg]) 123 | 124 | update_args = [kwargs[arg] for arg in self.__update_args__] 125 | 126 | out = self.message(*message_args) 127 | # out = scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i]) 128 | out = scatter_add(out, edge_index[i], dim, dim_size=size[i]) 129 | out = self.update(out, *update_args) 130 | 131 | return out 132 | 133 | def message(self, x_j): # pragma: no cover 134 | r"""Constructs 
messages to node :math:`i` in analogy to 135 | :math:`\phi_{\mathbf{\Theta}}` for each edge in 136 | :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and 137 | :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. 138 | Can take any argument which was initially passed to :meth:`propagate`. 139 | In addition, tensors passed to :meth:`propagate` can be mapped to the 140 | respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or 141 | :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`. 142 | """ 143 | 144 | return x_j 145 | 146 | def update(self, aggr_out): # pragma: no cover 147 | r"""Updates node embeddings in analogy to 148 | :math:`\gamma_{\mathbf{\Theta}}` for each node 149 | :math:`i \in \mathcal{V}`. 150 | Takes in the output of aggregation as first argument and any argument 151 | which was initially passed to :meth:`propagate`.""" 152 | 153 | return aggr_out 154 | -------------------------------------------------------------------------------- /imports/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot, Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
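# Helper routines used by the fetch/process scripts: locate per-subject ROI time-series files,
# compute connectivity matrices (correlation, partial correlation, tangent), read phenotypic
# scores, and load the precomputed networks as feature matrices.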
16 | 17 | 18 | import os 19 | import warnings 20 | import glob 21 | import csv 22 | import re 23 | import numpy as np 24 | import scipy.io as sio 25 | import sys 26 | from nilearn import connectome 27 | import pandas as pd 28 | from scipy.spatial import distance 29 | from scipy import signal 30 | from sklearn.compose import ColumnTransformer 31 | from sklearn.preprocessing import Normalizer 32 | from sklearn.preprocessing import OrdinalEncoder 33 | from sklearn.preprocessing import OneHotEncoder 34 | from sklearn.preprocessing import StandardScaler 35 | warnings.filterwarnings("ignore") 36 | 37 | # Input data variables 38 | 39 | root_folder = '/home/azureuser/projects/BrainGNN/data/' 40 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal') 41 | phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv') 42 | 43 | 44 | def fetch_filenames(subject_IDs, file_type, atlas): 45 | """ 46 | subject_list : list of short subject IDs in string format 47 | file_type : must be one of the available file types 48 | filemapping : resulting file name format 49 | returns: 50 | filenames : list of filetypes (same length as subject_list) 51 | """ 52 | 53 | filemapping = {'func_preproc': '_func_preproc.nii.gz', 54 | 'rois_' + atlas: '_rois_' + atlas + '.1D'} 55 | # The list to be filled 56 | filenames = [] 57 | 58 | # Fill list with requested file paths 59 | for i in range(len(subject_IDs)): 60 | os.chdir(data_folder) 61 | try: 62 | try: 63 | os.chdir(data_folder) 64 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) 65 | except: 66 | os.chdir(data_folder + '/' + subject_IDs[i]) 67 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) 68 | except IndexError: 69 | filenames.append('N/A') 70 | return filenames 71 | 72 | 73 | # Get timeseries arrays for list of subjects 74 | def get_timeseries(subject_list, atlas_name, silence=False): 75 | """ 76 | subject_list : list of short subject IDs in string format 77 | atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200 78 | returns: 79 | time_series : list of timeseries arrays, each of shape (timepoints x regions) 80 | """ 81 | 82 | timeseries = [] 83 | for i in range(len(subject_list)): 84 | subject_folder = os.path.join(data_folder, subject_list[i]) 85 | ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')] 86 | fl = os.path.join(subject_folder, ro_file[0]) 87 | if silence != True: 88 | print("Reading timeseries file %s" % fl) 89 | timeseries.append(np.loadtxt(fl, skiprows=0)) 90 | 91 | return timeseries 92 | 93 | 94 | # compute connectivity matrices 95 | def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234, 96 | n_subjects='', save=True, save_path=data_folder): 97 | """ 98 | timeseries : timeseries table for subject (timepoints x regions) 99 | subjects : subject IDs 100 | atlas_name : name of the parcellation atlas used 101 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 102 | iter_no : tangent connectivity iteration number for cross validation evaluation 103 | save : save the connectivity matrix to a file 104 | save_path : specify path to save the matrix if different from subject folder 105 | returns: 106 | connectivity : connectivity matrix (regions x regions) 107 | """ 108 | 109 | if kind in ['TPE', 'TE', 'correlation','partial correlation']: 110 | if kind not in ['TPE', 'TE']: 111 | conn_measure = connectome.ConnectivityMeasure(kind=kind) 112 | connectivity = conn_measure.fit_transform(timeseries) 113 | else: 114 | if kind == 'TPE': 115 | conn_measure = connectome.ConnectivityMeasure(kind='correlation') 116 | conn_mat = conn_measure.fit_transform(timeseries) 117 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 118 | connectivity_fit = conn_measure.fit(conn_mat) 119 | connectivity = connectivity_fit.transform(conn_mat) 120 | else: 121 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 122 | connectivity_fit = conn_measure.fit(timeseries) 123 | connectivity = connectivity_fit.transform(timeseries) 124 | 125 | if save: 126 | if kind not in ['TPE', 'TE']: 127 | for i, subj_id in enumerate(subjects): 128 | subject_file = os.path.join(save_path, subj_id, 129 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat') 130 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 131 | return connectivity 132 | else: 133 | for i, subj_id in enumerate(subjects): 134 | subject_file = os.path.join(save_path, subj_id, 135 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str( 136 | iter_no) + '_' + str(seed) + '_' + validation_ext + str( 137 | n_subjects) + '.mat') 138 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 139 | return connectivity_fit 140 | 141 | 142 | # Get the list of subject IDs 143 | 144 | def get_ids(num_subjects=None): 145 | """ 146 | return: 147 | subject_IDs : list of all subject IDs 148 | """ 149 | 150 | subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt'), dtype=str) 151 | 152 | if num_subjects is not None: 153 | subject_IDs = subject_IDs[:num_subjects] 154 | 155 | return subject_IDs 156 | 157 | 158 | # Get phenotype values for a list of subjects 159 | def get_subject_score(subject_list, score): 160 | scores_dict = {} 161 | 162 | with open(phenotype) as csv_file: 163 | reader = csv.DictReader(csv_file) 164 | for row in reader: 165 | if row['SUB_ID'] in subject_list: 166 | if score == 'HANDEDNESS_CATEGORY': 167 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 168 | scores_dict[row['SUB_ID']] = 'R' 169 | elif row[score] == 'Mixed': 170 | scores_dict[row['SUB_ID']] = 'Ambi' 171 | elif row[score] == 'L->R': 172 | scores_dict[row['SUB_ID']] = 'Ambi' 173 | else: 174 | scores_dict[row['SUB_ID']] = row[score] 175 | elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'): 176 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 177 | scores_dict[row['SUB_ID']] = 100 178 | else: 179 | scores_dict[row['SUB_ID']] = float(row[score]) 180 | 181 | else: 182 | scores_dict[row['SUB_ID']] = row[score] 183 | 184 | return scores_dict 185 | 186 | 187 | # preprocess phenotypes. 
Categorical -> ordinal representation 188 | def preprocess_phenotypes(pheno_ft, params): 189 | if params['model'] == 'MIDA': 190 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough') 191 | else: 192 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough') 193 | 194 | pheno_ft = ct.fit_transform(pheno_ft) 195 | pheno_ft = pheno_ft.astype('float32') 196 | 197 | return (pheno_ft) 198 | 199 | 200 | # create phenotype feature vector to concatenate with fmri feature vectors 201 | def phenotype_ft_vector(pheno_ft, num_subjects, params): 202 | gender = pheno_ft[:, 0] 203 | if params['model'] == 'MIDA': 204 | eye = pheno_ft[:, 0] 205 | hand = pheno_ft[:, 2] 206 | age = pheno_ft[:, 3] 207 | fiq = pheno_ft[:, 4] 208 | else: 209 | eye = pheno_ft[:, 2] 210 | hand = pheno_ft[:, 3] 211 | age = pheno_ft[:, 4] 212 | fiq = pheno_ft[:, 5] 213 | 214 | phenotype_ft = np.zeros((num_subjects, 4)) 215 | phenotype_ft_eye = np.zeros((num_subjects, 2)) 216 | phenotype_ft_hand = np.zeros((num_subjects, 3)) 217 | 218 | for i in range(num_subjects): 219 | phenotype_ft[i, int(gender[i])] = 1 220 | phenotype_ft[i, -2] = age[i] 221 | phenotype_ft[i, -1] = fiq[i] 222 | phenotype_ft_eye[i, int(eye[i])] = 1 223 | phenotype_ft_hand[i, int(hand[i])] = 1 224 | 225 | if params['model'] == 'MIDA': 226 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) 227 | else: 228 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1) 229 | 230 | return phenotype_ft 231 | 232 | 233 | # Load precomputed fMRI connectivity networks 234 | def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal", 235 | variable='connectivity'): 236 | """ 237 | subject_list : list of subject IDs 238 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 239 | atlas_name : name of the parcellation atlas used 240 | variable : variable name in the .mat file that has been used to save the precomputed networks 241 | return: 242 | matrix : feature matrix of connectivity networks (num_subjects x network_size) 243 | """ 244 | 245 | all_networks = [] 246 | for subject in subject_list: 247 | if len(kind.split()) == 2: 248 | kind = '_'.join(kind.split()) 249 | fl = os.path.join(data_folder, subject, 250 | subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat") 251 | 252 | 253 | matrix = sio.loadmat(fl)[variable] 254 | all_networks.append(matrix) 255 | 256 | if kind in ['TE', 'TPE']: 257 | norm_networks = [mat for mat in all_networks] 258 | else: 259 | norm_networks = [np.arctanh(mat) for mat in all_networks] 260 | 261 | networks = np.stack(norm_networks) 262 | 263 | return networks 264 | 265 | -------------------------------------------------------------------------------- /03-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import time 5 | import copy 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | from tensorboardX import SummaryWriter 11 | 12 | from imports.ABIDEDataset import ABIDEDataset 13 | from torch_geometric.data import DataLoader 14 | from net.braingnn import Network 15 | from imports.utils import train_val_test_split 16 | from sklearn.metrics import classification_report, confusion_matrix 17 | 18 | torch.manual_seed(123) 19 | 20 | EPS = 1e-10 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--epoch', type=int, default=0, help='starting epoch') 26 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training') 27 | parser.add_argument('--batchSize', type=int, default=100, help='size of the batches') 28 | parser.add_argument('--dataroot', type=str, default='/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal', help='root directory of the dataset') 29 | parser.add_argument('--fold', type=int, default=0, help='training which fold') 30 | parser.add_argument('--lr', type = float, default=0.01, help='learning rate') 31 | parser.add_argument('--stepsize', type=int, default=20, help='scheduler step size') 32 | parser.add_argument('--gamma', type=float, default=0.5, help='scheduler shrinking rate') 33 | parser.add_argument('--weightdecay', type=float, default=5e-3, help='regularization') 34 | parser.add_argument('--lamb0', type=float, default=1, help='classification loss weight') 35 | parser.add_argument('--lamb1', type=float, default=0, help='s1 unit regularization') 36 | parser.add_argument('--lamb2', type=float, default=0, help='s2 unit regularization') 37 | parser.add_argument('--lamb3', type=float, default=0.1, help='s1 entropy regularization') 38 | parser.add_argument('--lamb4', type=float, default=0.1, help='s2 entropy regularization') 39 | parser.add_argument('--lamb5', type=float, default=0.1, help='s1 consistence regularization') 40 | parser.add_argument('--layer', type=int, default=2, help='number of GNN layers') 41 | parser.add_argument('--ratio', type=float, default=0.5, help='pooling ratio') 42 | parser.add_argument('--indim', type=int, default=200, help='feature dim') 43 | parser.add_argument('--nroi', type=int, default=200, help='num of ROIs') 44 | parser.add_argument('--nclass', 
type=int, default=2, help='num of classes') 45 | parser.add_argument('--load_model', type=bool, default=False) 46 | parser.add_argument('--save_model', type=bool, default=True) 47 | parser.add_argument('--optim', type=str, default='Adam', help='optimization method: SGD, Adam') 48 | parser.add_argument('--save_path', type=str, default='./model/', help='path to save model') 49 | opt = parser.parse_args() 50 | 51 | if not os.path.exists(opt.save_path): 52 | os.makedirs(opt.save_path) 53 | 54 | #################### Parameter Initialization ####################### 55 | path = opt.dataroot 56 | name = 'ABIDE' 57 | save_model = opt.save_model 58 | load_model = opt.load_model 59 | opt_method = opt.optim 60 | num_epoch = opt.n_epochs 61 | fold = opt.fold 62 | writer = SummaryWriter(os.path.join('./log',str(fold))) 63 | 64 | 65 | 66 | ################## Define Dataloader ################################## 67 | 68 | dataset = ABIDEDataset(path,name) 69 | dataset.data.y = dataset.data.y.squeeze() 70 | dataset.data.x[dataset.data.x == float('inf')] = 0 71 | 72 | tr_index,val_index,te_index = train_val_test_split(fold=fold) 73 | train_dataset = dataset[tr_index] 74 | val_dataset = dataset[val_index] 75 | test_dataset = dataset[te_index] 76 | 77 | 78 | train_loader = DataLoader(train_dataset,batch_size=opt.batchSize, shuffle= True) 79 | val_loader = DataLoader(val_dataset, batch_size=opt.batchSize, shuffle=False) 80 | test_loader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False) 81 | 82 | 83 | 84 | ############### Define Graph Deep Learning Network ########################## 85 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device) 86 | print(model) 87 | 88 | if opt_method == 'Adam': 89 | optimizer = torch.optim.Adam(model.parameters(), lr= opt.lr, weight_decay=opt.weightdecay) 90 | elif opt_method == 'SGD': 91 | optimizer = torch.optim.SGD(model.parameters(), lr =opt.lr, momentum = 0.9, weight_decay=opt.weightdecay, nesterov = True) 92 | 93 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.stepsize, gamma=opt.gamma) 94 | 95 | ############################### Define Other Loss Functions ######################################## 96 | def topk_loss(s,ratio): 97 | if ratio > 0.5: 98 | ratio = 1-ratio 99 | s = s.sort(dim=1).values 100 | res = -torch.log(s[:,-int(s.size(1)*ratio):]+EPS).mean() -torch.log(1-s[:,:int(s.size(1)*ratio)]+EPS).mean() 101 | return res 102 | 103 | 104 | def consist_loss(s): 105 | if len(s) == 0: 106 | return 0 107 | s = torch.sigmoid(s) 108 | W = torch.ones(s.shape[0],s.shape[0]) 109 | D = torch.eye(s.shape[0])*torch.sum(W,dim=1) 110 | L = D-W 111 | L = L.to(device) 112 | res = torch.trace(torch.transpose(s,0,1) @ L @ s)/(s.shape[0]*s.shape[0]) 113 | return res 114 | 115 | ###################### Network Training Function##################################### 116 | def train(epoch): 117 | print('train...........') 118 | scheduler.step() 119 | 120 | for param_group in optimizer.param_groups: 121 | print("LR", param_group['lr']) 122 | model.train() 123 | s1_list = [] 124 | s2_list = [] 125 | loss_all = 0 126 | step = 0 127 | for data in train_loader: 128 | data = data.to(device) 129 | optimizer.zero_grad() 130 | output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos) 131 | s1_list.append(s1.view(-1).detach().cpu().numpy()) 132 | s2_list.append(s2.view(-1).detach().cpu().numpy()) 133 | 134 | loss_c = F.nll_loss(output, data.y) 135 | 136 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2 137 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2 
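        # topk_loss (defined above) regularizes the TopK pooling scores: after sorting, the top
        # `ratio` fraction of scores in s1/s2 is pushed towards 1 and the bottom fraction towards 0.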
138 | loss_tpk1 = topk_loss(s1,opt.ratio) 139 | loss_tpk2 = topk_loss(s2,opt.ratio) 140 | loss_consist = 0 141 | for c in range(opt.nclass): 142 | loss_consist += consist_loss(s1[data.y == c]) 143 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \ 144 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist 145 | writer.add_scalar('train/classification_loss', loss_c, epoch*len(train_loader)+step) 146 | writer.add_scalar('train/unit_loss1', loss_p1, epoch*len(train_loader)+step) 147 | writer.add_scalar('train/unit_loss2', loss_p2, epoch*len(train_loader)+step) 148 | writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch*len(train_loader)+step) 149 | writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch*len(train_loader)+step) 150 | writer.add_scalar('train/GCL_loss', loss_consist, epoch*len(train_loader)+step) 151 | step = step + 1 152 | 153 | loss.backward() 154 | loss_all += loss.item() * data.num_graphs 155 | optimizer.step() 156 | 157 | s1_arr = np.hstack(s1_list) 158 | s2_arr = np.hstack(s2_list) 159 | return loss_all / len(train_dataset), s1_arr, s2_arr ,w1,w2 160 | 161 | 162 | ###################### Network Testing Function##################################### 163 | def test_acc(loader): 164 | model.eval() 165 | correct = 0 166 | for data in loader: 167 | data = data.to(device) 168 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 169 | pred = outputs[0].max(dim=1)[1] 170 | correct += pred.eq(data.y).sum().item() 171 | 172 | return correct / len(loader.dataset) 173 | 174 | def test_loss(loader,epoch): 175 | print('testing...........') 176 | model.eval() 177 | loss_all = 0 178 | for data in loader: 179 | data = data.to(device) 180 | output, w1, w2, s1, s2= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 181 | loss_c = F.nll_loss(output, data.y) 182 | 183 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2 184 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2 185 | loss_tpk1 = topk_loss(s1,opt.ratio) 186 | loss_tpk2 = topk_loss(s2,opt.ratio) 187 | loss_consist = 0 188 | for c in range(opt.nclass): 189 | loss_consist += consist_loss(s1[data.y == c]) 190 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \ 191 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist 192 | 193 | loss_all += loss.item() * data.num_graphs 194 | return loss_all / len(loader.dataset) 195 | 196 | ####################################################################################### 197 | ############################ Model Training ######################################### 198 | ####################################################################################### 199 | best_model_wts = copy.deepcopy(model.state_dict()) 200 | best_loss = 1e10 201 | for epoch in range(0, num_epoch): 202 | since = time.time() 203 | tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch) 204 | tr_acc = test_acc(train_loader) 205 | val_acc = test_acc(val_loader) 206 | val_loss = test_loss(val_loader,epoch) 207 | time_elapsed = time.time() - since 208 | print('*====**') 209 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 210 | print('Epoch: {:03d}, Train Loss: {:.7f}, ' 211 | 'Train Acc: {:.7f}, Test Loss: {:.7f}, Test Acc: {:.7f}'.format(epoch, tr_loss, 212 | tr_acc, val_loss, val_acc)) 213 | 214 | writer.add_scalars('Acc',{'train_acc':tr_acc,'val_acc':val_acc}, epoch) 215 | writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss}, epoch) 216 | writer.add_histogram('Hist/hist_s1', s1_arr, 
epoch) 217 | writer.add_histogram('Hist/hist_s2', s2_arr, epoch) 218 | 219 | if val_loss < best_loss and epoch > 5: 220 | print("saving best model") 221 | best_loss = val_loss 222 | best_model_wts = copy.deepcopy(model.state_dict()) 223 | if save_model: 224 | torch.save(best_model_wts, os.path.join(opt.save_path,str(fold)+'.pth')) 225 | 226 | ####################################################################################### 227 | ######################### Testing on testing set ###################################### 228 | ####################################################################################### 229 | 230 | if opt.load_model: 231 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device) 232 | model.load_state_dict(torch.load(os.path.join(opt.save_path,str(fold)+'.pth'))) 233 | model.eval() 234 | preds = [] 235 | correct = 0 236 | for data in val_loader: 237 | data = data.to(device) 238 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 239 | pred = outputs[0].max(1)[1] 240 | preds.append(pred.cpu().detach().numpy()) 241 | correct += pred.eq(data.y).sum().item() 242 | preds = np.concatenate(preds,axis=0) 243 | trues = val_dataset.data.y.cpu().detach().numpy() 244 | cm = confusion_matrix(trues,preds) 245 | print("Confusion matrix") 246 | print(classification_report(trues, preds)) 247 | 248 | else: 249 | model.load_state_dict(best_model_wts) 250 | model.eval() 251 | test_accuracy = test_acc(test_loader) 252 | test_l= test_loss(test_loader,0) 253 | print("===========================") 254 | print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l)) 255 | print(opt) 256 | 257 | -------------------------------------------------------------------------------- /data/subject_ID.txt: -------------------------------------------------------------------------------- 1 | 50128 2 | 51203 3 | 50325 4 | 50117 5 | 50573 6 | 50741 7 | 50779 8 | 51009 9 | 50746 10 | 50574 11 | 50110 12 | 50322 13 | 51036 14 | 51204 15 | 50119 16 | 50126 17 | 50314 18 | 51490 19 | 50784 20 | 51464 21 | 51000 22 | 51038 23 | 50748 24 | 51235 25 | 51007 26 | 51463 27 | 50783 28 | 50777 29 | 50313 30 | 50121 31 | 51053 32 | 51261 33 | 50723 34 | 50511 35 | 51295 36 | 50347 37 | 50982 38 | 50976 39 | 51098 40 | 51292 41 | 50340 42 | 50516 43 | 50724 44 | 51266 45 | 51054 46 | 50186 47 | 50529 48 | 50985 49 | 50520 50 | 50376 51 | 50978 52 | 50144 53 | 51096 54 | 50382 55 | 51250 56 | 51062 57 | 50349 58 | 51065 59 | 50385 60 | 51257 61 | 50143 62 | 51091 63 | 50371 64 | 50527 65 | 51268 66 | 50188 67 | 50518 68 | 50749 69 | 51039 70 | 50776 71 | 50120 72 | 50312 73 | 51006 74 | 51234 75 | 50782 76 | 51462 77 | 50118 78 | 51465 79 | 50785 80 | 51001 81 | 50315 82 | 50127 83 | 51491 84 | 51008 85 | 50778 86 | 51205 87 | 50575 88 | 50747 89 | 50111 90 | 50129 91 | 50116 92 | 50324 93 | 50740 94 | 50572 95 | 51030 96 | 51202 97 | 50370 98 | 50142 99 | 51090 100 | 50526 101 | 51256 102 | 51064 103 | 50519 104 | 50189 105 | 51269 106 | 51063 107 | 50383 108 | 51251 109 | 50521 110 | 50145 111 | 51097 112 | 50979 113 | 50377 114 | 50348 115 | 51055 116 | 50187 117 | 51267 118 | 51293 119 | 50341 120 | 50725 121 | 51258 122 | 50984 123 | 50528 124 | 50970 125 | 50510 126 | 50722 127 | 51294 128 | 50346 129 | 51260 130 | 51052 131 | 51099 132 | 50977 133 | 50379 134 | 50983 135 | 50039 136 | 50496 137 | 51312 138 | 50234 139 | 50006 140 | 50650 141 | 50802 142 | 50668 143 | 51118 144 | 50657 145 | 50233 146 | 51127 147 | 51315 148 | 50491 149 | 50008 150 | 50498 151 | 50037 
152 | 50205 153 | 50661 154 | 51581 155 | 50453 156 | 50695 157 | 51575 158 | 51111 159 | 51323 160 | 51129 161 | 50659 162 | 51324 163 | 51116 164 | 51572 165 | 50692 166 | 50666 167 | 50202 168 | 50030 169 | 51142 170 | 51370 171 | 50269 172 | 51189 173 | 50251 174 | 50407 175 | 50438 176 | 51348 177 | 50603 178 | 50267 179 | 51187 180 | 50055 181 | 51341 182 | 50293 183 | 51173 184 | 51174 185 | 51346 186 | 50294 187 | 51180 188 | 50052 189 | 50260 190 | 50604 191 | 50436 192 | 50658 193 | 51128 194 | 50667 195 | 50455 196 | 50031 197 | 50203 198 | 51117 199 | 51325 200 | 50693 201 | 51573 202 | 50499 203 | 50009 204 | 51574 205 | 50694 206 | 51322 207 | 51110 208 | 50204 209 | 50036 210 | 51580 211 | 50660 212 | 50803 213 | 50669 214 | 51314 215 | 51126 216 | 50490 217 | 50656 218 | 50232 219 | 50038 220 | 50804 221 | 50007 222 | 50235 223 | 50651 224 | 50463 225 | 50497 226 | 51121 227 | 51313 228 | 50261 229 | 51181 230 | 50053 231 | 50437 232 | 50605 233 | 51347 234 | 50295 235 | 51175 236 | 50408 237 | 51172 238 | 51340 239 | 50292 240 | 50602 241 | 51186 242 | 50054 243 | 50266 244 | 50259 245 | 50250 246 | 50406 247 | 51349 248 | 50439 249 | 50257 250 | 51188 251 | 50268 252 | 51195 253 | 50047 254 | 50275 255 | 50611 256 | 51161 257 | 51353 258 | 50281 259 | 51159 260 | 51354 261 | 50286 262 | 51166 263 | 50424 264 | 50616 265 | 50272 266 | 51192 267 | 50040 268 | 50049 269 | 51362 270 | 51150 271 | 50412 272 | 50620 273 | 50618 274 | 50288 275 | 51168 276 | 50627 277 | 50415 278 | 50243 279 | 51365 280 | 50441 281 | 50217 282 | 50025 283 | 50819 284 | 51331 285 | 51103 286 | 51567 287 | 50687 288 | 50826 289 | 51558 290 | 51560 291 | 51104 292 | 51336 293 | 50022 294 | 50210 295 | 50446 296 | 51309 297 | 50821 298 | 51132 299 | 51300 300 | 51556 301 | 50642 302 | 50470 303 | 50014 304 | 51569 305 | 50689 306 | 50817 307 | 50013 308 | 50477 309 | 50645 310 | 50483 311 | 51307 312 | 51135 313 | 50448 314 | 51338 315 | 51169 316 | 50289 317 | 50619 318 | 51364 319 | 51156 320 | 50414 321 | 50626 322 | 50242 323 | 50048 324 | 50245 325 | 50621 326 | 50413 327 | 51151 328 | 51363 329 | 50628 330 | 50617 331 | 50425 332 | 51193 333 | 50041 334 | 50273 335 | 51167 336 | 51355 337 | 50287 338 | 51352 339 | 50280 340 | 51160 341 | 50274 342 | 51194 343 | 50046 344 | 50422 345 | 50610 346 | 50482 347 | 51134 348 | 51306 349 | 50012 350 | 50644 351 | 51339 352 | 50449 353 | 50643 354 | 50015 355 | 51301 356 | 51133 357 | 50485 358 | 51557 359 | 50816 360 | 50688 361 | 51568 362 | 50211 363 | 50023 364 | 50447 365 | 51561 366 | 51105 367 | 50820 368 | 51308 369 | 51102 370 | 51330 371 | 50686 372 | 51566 373 | 50440 374 | 50818 375 | 50024 376 | 50216 377 | 51559 378 | 50169 379 | 50955 380 | 50156 381 | 51084 382 | 50364 383 | 50700 384 | 50532 385 | 51070 386 | 50390 387 | 51048 388 | 50952 389 | 50738 390 | 50397 391 | 51077 392 | 50999 393 | 50707 394 | 50363 395 | 51083 396 | 50990 397 | 50158 398 | 50964 399 | 51273 400 | 51041 401 | 50193 402 | 50355 403 | 50167 404 | 50503 405 | 50731 406 | 50709 407 | 50399 408 | 51079 409 | 50997 410 | 50736 411 | 50504 412 | 50160 413 | 51280 414 | 50352 415 | 51046 416 | 50194 417 | 51274 418 | 51482 419 | 50306 420 | 50134 421 | 51220 422 | 51012 423 | 51476 424 | 50796 425 | 50339 426 | 50791 427 | 51471 428 | 51015 429 | 51227 430 | 50133 431 | 50301 432 | 50557 433 | 51485 434 | 51218 435 | 50568 436 | 51023 437 | 51211 438 | 50753 439 | 50561 440 | 50105 441 | 50337 442 | 51478 443 | 50798 444 | 50308 445 | 50330 446 | 50102 447 | 50566 
448 | 50754 449 | 51216 450 | 51024 451 | 50559 452 | 51229 453 | 50996 454 | 51078 455 | 50962 456 | 50708 457 | 51275 458 | 51047 459 | 50195 460 | 50505 461 | 50737 462 | 51281 463 | 50353 464 | 50161 465 | 50965 466 | 50159 467 | 50991 468 | 50166 469 | 50354 470 | 50730 471 | 50502 472 | 51040 473 | 50192 474 | 51272 475 | 50739 476 | 51049 477 | 50706 478 | 50150 479 | 51082 480 | 50362 481 | 50998 482 | 51076 483 | 50954 484 | 50168 485 | 50391 486 | 51071 487 | 50365 488 | 50157 489 | 51085 490 | 50701 491 | 51025 492 | 51217 493 | 50103 494 | 50331 495 | 50755 496 | 50567 497 | 51228 498 | 50558 499 | 50560 500 | 50752 501 | 50336 502 | 50104 503 | 51210 504 | 50799 505 | 51479 506 | 50300 507 | 50132 508 | 50556 509 | 51484 510 | 51470 511 | 50790 512 | 51226 513 | 51014 514 | 50569 515 | 51219 516 | 51013 517 | 51221 518 | 50797 519 | 51477 520 | 50551 521 | 51483 522 | 50135 523 | 50307 524 | 50338 525 | 50171 526 | 50343 527 | 51291 528 | 50727 529 | 50515 530 | 50185 531 | 51057 532 | 51265 533 | 50972 534 | 50388 535 | 50986 536 | 51068 537 | 51262 538 | 50182 539 | 51050 540 | 51606 541 | 50344 542 | 51296 543 | 50981 544 | 50149 545 | 51254 546 | 50386 547 | 50988 548 | 51066 549 | 50372 550 | 50524 551 | 51059 552 | 50711 553 | 50523 554 | 51095 555 | 50147 556 | 50375 557 | 51061 558 | 51253 559 | 50381 560 | 51298 561 | 51238 562 | 50577 563 | 50745 564 | 50321 565 | 50113 566 | 51207 567 | 51035 568 | 51469 569 | 50789 570 | 50319 571 | 51456 572 | 51032 573 | 50114 574 | 50326 575 | 50742 576 | 50570 577 | 51209 578 | 51236 579 | 50780 580 | 51460 581 | 50774 582 | 50122 583 | 50310 584 | 51458 585 | 50317 586 | 50125 587 | 51493 588 | 50773 589 | 51467 590 | 50787 591 | 51231 592 | 51003 593 | 51252 594 | 50380 595 | 51060 596 | 50710 597 | 50374 598 | 51094 599 | 50146 600 | 51299 601 | 51093 602 | 50373 603 | 50525 604 | 51067 605 | 50989 606 | 51255 607 | 50387 608 | 50728 609 | 51058 610 | 50345 611 | 51297 612 | 50183 613 | 51051 614 | 51263 615 | 51607 616 | 50148 617 | 50974 618 | 51264 619 | 50184 620 | 51056 621 | 50342 622 | 50170 623 | 50514 624 | 50726 625 | 51069 626 | 50987 627 | 50973 628 | 51459 629 | 50329 630 | 50786 631 | 51466 632 | 51002 633 | 51230 634 | 50124 635 | 50316 636 | 50772 637 | 51492 638 | 50578 639 | 51208 640 | 50775 641 | 50311 642 | 50123 643 | 51237 644 | 51461 645 | 50781 646 | 50318 647 | 50788 648 | 51468 649 | 50327 650 | 50115 651 | 50571 652 | 50743 653 | 51457 654 | 51201 655 | 51033 656 | 51239 657 | 51034 658 | 51206 659 | 50744 660 | 50576 661 | 50112 662 | 50320 663 | 50060 664 | 50252 665 | 50404 666 | 51146 667 | 50609 668 | 50299 669 | 51179 670 | 51373 671 | 51141 672 | 50403 673 | 50255 674 | 50058 675 | 50297 676 | 51345 677 | 51177 678 | 50263 679 | 50051 680 | 51183 681 | 50435 682 | 50607 683 | 51148 684 | 50056 685 | 51184 686 | 50264 687 | 51170 688 | 50290 689 | 51342 690 | 50801 691 | 51329 692 | 50466 693 | 50654 694 | 51316 695 | 51124 696 | 50492 697 | 51578 698 | 50698 699 | 50208 700 | 51123 701 | 51311 702 | 50005 703 | 50237 704 | 50653 705 | 51318 706 | 50468 707 | 51327 708 | 50691 709 | 51571 710 | 50665 711 | 51585 712 | 50033 713 | 50201 714 | 50239 715 | 50206 716 | 50034 717 | 51582 718 | 51576 719 | 50696 720 | 51320 721 | 51112 722 | 50291 723 | 51343 724 | 51171 725 | 50433 726 | 50601 727 | 50265 728 | 50057 729 | 51185 730 | 50050 731 | 51182 732 | 50262 733 | 50606 734 | 50434 735 | 50296 736 | 51344 737 | 51149 738 | 50402 739 | 50254 740 | 51140 741 | 50059 742 | 51147 743 | 50253 
744 | 50405 745 | 51178 746 | 50298 747 | 50608 748 | 50697 749 | 51577 750 | 51113 751 | 51321 752 | 50035 753 | 50207 754 | 50663 755 | 51583 756 | 50469 757 | 51319 758 | 51584 759 | 50664 760 | 50200 761 | 50032 762 | 51326 763 | 51114 764 | 51570 765 | 50690 766 | 50807 767 | 50209 768 | 50699 769 | 51579 770 | 50236 771 | 50004 772 | 50652 773 | 50494 774 | 51122 775 | 51328 776 | 50800 777 | 51317 778 | 50493 779 | 50655 780 | 50467 781 | 50003 782 | 51563 783 | 50683 784 | 51335 785 | 51107 786 | 50213 787 | 50445 788 | 51138 789 | 50648 790 | 50822 791 | 50442 792 | 50026 793 | 50214 794 | 51100 795 | 51332 796 | 51564 797 | 50019 798 | 50825 799 | 50489 800 | 50010 801 | 50646 802 | 50480 803 | 51136 804 | 51304 805 | 51109 806 | 51303 807 | 51131 808 | 50487 809 | 50017 810 | 50028 811 | 50814 812 | 50418 813 | 51165 814 | 50285 815 | 51357 816 | 50615 817 | 50427 818 | 50043 819 | 51191 820 | 50271 821 | 50249 822 | 50276 823 | 50044 824 | 51196 825 | 50612 826 | 50282 827 | 51350 828 | 51162 829 | 51359 830 | 50416 831 | 50624 832 | 50240 833 | 51154 834 | 50278 835 | 51198 836 | 51153 837 | 51361 838 | 50247 839 | 50623 840 | 50411 841 | 50016 842 | 51130 843 | 51302 844 | 50486 845 | 50815 846 | 50029 847 | 50481 848 | 51305 849 | 51137 850 | 50011 851 | 50647 852 | 50812 853 | 51333 854 | 51101 855 | 51565 856 | 50685 857 | 50443 858 | 50215 859 | 50027 860 | 50488 861 | 50824 862 | 50020 863 | 50212 864 | 50444 865 | 50682 866 | 51562 867 | 51106 868 | 51334 869 | 50649 870 | 50823 871 | 51139 872 | 51199 873 | 50279 874 | 50246 875 | 50410 876 | 50622 877 | 51360 878 | 51152 879 | 51358 880 | 50428 881 | 51155 882 | 50625 883 | 50417 884 | 50241 885 | 50248 886 | 51163 887 | 50283 888 | 51351 889 | 50045 890 | 51197 891 | 50277 892 | 50613 893 | 50421 894 | 50419 895 | 51369 896 | 50426 897 | 50614 898 | 50270 899 | 50042 900 | 51190 901 | 50284 902 | 51356 903 | 51164 904 | 51472 905 | 50792 906 | 51224 907 | 51016 908 | 50302 909 | 50130 910 | 51486 911 | 50554 912 | 51029 913 | 51481 914 | 50553 915 | 50305 916 | 51011 917 | 51223 918 | 50795 919 | 50333 920 | 50757 921 | 50565 922 | 51027 923 | 51215 924 | 51488 925 | 51018 926 | 51212 927 | 51020 928 | 50562 929 | 50750 930 | 50334 931 | 50106 932 | 51279 933 | 50199 934 | 50509 935 | 51074 936 | 50704 937 | 51080 938 | 50152 939 | 50360 940 | 50956 941 | 50358 942 | 50367 943 | 51087 944 | 50969 945 | 50531 946 | 50703 947 | 51241 948 | 51073 949 | 50960 950 | 50994 951 | 51248 952 | 50507 953 | 50735 954 | 50351 955 | 50163 956 | 51277 957 | 50197 958 | 51045 959 | 50993 960 | 50369 961 | 51089 962 | 50967 963 | 50190 964 | 51042 965 | 50164 966 | 50958 967 | 50356 968 | 50732 969 | 50500 970 | 50751 971 | 50563 972 | 50107 973 | 50335 974 | 51021 975 | 51213 976 | 51214 977 | 51026 978 | 50332 979 | 50564 980 | 50756 981 | 51019 982 | 51489 983 | 51222 984 | 51010 985 | 51474 986 | 50794 987 | 51480 988 | 50552 989 | 50304 990 | 50136 991 | 50109 992 | 50131 993 | 50303 994 | 51487 995 | 50555 996 | 50793 997 | 51473 998 | 51017 999 | 51225 1000 | 51028 1001 | 50966 1002 | 51088 1003 | 50368 1004 | 50992 1005 | 50357 1006 | 50959 1007 | 50501 1008 | 50733 1009 | 51271 1010 | 50191 1011 | 51249 1012 | 50995 1013 | 50961 1014 | 50196 1015 | 51044 1016 | 51276 1017 | 50162 1018 | 50350 1019 | 51282 1020 | 50359 1021 | 50957 1022 | 51072 1023 | 51240 1024 | 50968 1025 | 51086 1026 | 50366 1027 | 50702 1028 | 50530 1029 | 50198 1030 | 51278 1031 | 50705 1032 | 50361 1033 | 51081 1034 | 50153 1035 | 51075 1036 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work 2 | anaconda-client==1.7.2 3 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work 4 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist 5 | appdirs==1.4.4 6 | argh==0.26.2 7 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work 8 | arrow==0.13.1 9 | ase==3.21.1 10 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work 11 | astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work 12 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work 13 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work 14 | atomicwrites==1.4.0 15 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work 16 | autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work 17 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work 18 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 19 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work 20 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work 21 | binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work 22 | bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work 23 | bkcharts==0.2 24 | black==19.10b0 25 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work 26 | bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work 27 | boto==2.49.0 28 | Bottleneck==1.3.2 29 | brotlipy==0.7.0 30 | certifi==2020.12.5 31 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work 32 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work 33 | click @ file:///home/linux1/recipes/ci/click_1610990599742/work 34 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work 35 | clyent==1.2.2 36 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 37 | contextlib2==0.6.0.post1 38 | cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work 39 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work 40 | cycler==0.10.0 41 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work 42 | cytoolz==0.11.0 43 | dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work 44 | decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work 45 | deepdish==0.3.6 46 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work 47 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work 48 | distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work 49 | docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work 50 | entrypoints==0.3 51 | et-xmlfile==1.0.1 52 | fastcache==1.1.0 53 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work 54 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work 55 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work 56 | fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work 57 | future==0.18.2 58 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work 59 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work 60 | gmpy2==2.0.8 61 | googledrivedownloader==0.4 62 | greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work 63 | 
h5py==2.10.0 64 | HeapDict==1.0.1 65 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work 66 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work 67 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work 68 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work 69 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work 70 | inflection==0.5.1 71 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work 72 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work 73 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl 74 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work 75 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 76 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work 77 | isodate==0.6.0 78 | isort @ file:///tmp/build/80754af9/isort_1616355431277/work 79 | itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work 80 | jdcal==1.4.1 81 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work 82 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work 83 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work 84 | jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work 85 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work 86 | json5==0.9.5 87 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work 88 | jupyter==1.0.0 89 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work 90 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work 91 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work 92 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work 93 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work 94 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work 95 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work 96 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work 97 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work 98 | keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work 99 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work 100 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work 101 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work 102 | llvmlite==0.36.0 103 | locket==0.2.1 104 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work 105 | MarkupSafe==1.1.1 106 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work 107 | mccabe==0.6.1 108 | mistune==0.8.4 109 | mkl-fft==1.3.0 110 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work 111 | mkl-service==2.3.0 112 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work 113 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work 114 | mpmath==1.2.1 115 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work 116 | multipledispatch==0.6.0 117 | mypy-extensions==0.4.3 118 | nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work 119 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work 120 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work 
121 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work 122 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work 123 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work 124 | nibabel==3.2.1 125 | nilearn==0.7.1 126 | nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work 127 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work 128 | notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work 129 | numba @ file:///tmp/build/80754af9/numba_1616774046117/work 130 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work 131 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work 132 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work 133 | olefile==0.46 134 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work 135 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work 136 | pandas==1.2.4 137 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work 138 | parso==0.7.0 139 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work 140 | path @ file:///tmp/build/80754af9/path_1614022220526/work 141 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work 142 | pathspec==0.7.0 143 | pathtools==0.1.2 144 | patsy==0.5.1 145 | pep8==1.7.1 146 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 147 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 148 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work 149 | pkginfo==1.7.0 150 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work 151 | ply==3.11 152 | poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work 153 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1618088486455/work 154 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 155 | protobuf==3.17.0 156 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work 157 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 158 | py @ file:///tmp/build/80754af9/py_1607971587848/work 159 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work 160 | pycosat==0.6.3 161 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 162 | pycurl==7.43.0.6 163 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work 164 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work 165 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work 166 | Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work 167 | pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work 168 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work 169 | pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work 170 | pyodbc===4.0.0-unsupported 171 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 172 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work 173 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work 174 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 175 | pytest==6.2.3 176 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work 177 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work 178 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work 179 | python-louvain==0.15 180 | python-slugify @ 
file:///tmp/build/80754af9/python-slugify_1620405669636/work 181 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 182 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work 183 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work 184 | PyYAML==5.4.1 185 | pyzmq==20.0.0 186 | QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work 187 | qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl 188 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work 189 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work 190 | QtPy==1.9.0 191 | rdflib==5.0.0 192 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work 193 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work 194 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work 195 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work 196 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work 197 | scikit-image==0.16.2 198 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work 199 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work 200 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work 201 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work 202 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work 203 | simplegeneric==0.8.1 204 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work 205 | sip==4.19.13 206 | six @ file:///tmp/build/80754af9/six_1605205327372/work 207 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work 208 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work 209 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work 210 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work 211 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work 212 | Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work 213 | sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work 214 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work 215 | sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work 216 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work 217 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work 218 | sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work 219 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work 220 | spyder @ file:///tmp/build/80754af9/spyder_1618327905127/work 221 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1617396566288/work 222 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work 223 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work 224 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work 225 | tables==3.6.1 226 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work 227 | tensorboardX==2.2 228 | terminado==0.9.4 229 | testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work 230 | text-unidecode==1.3 231 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work 232 | threadpoolctl @ 
file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl 233 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work 234 | tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work 235 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work 236 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work 237 | torch==1.7.0 238 | torch-cluster==1.5.9 239 | torch-geometric==1.7.0 240 | torch-scatter==2.0.6 241 | torch-sparse==0.6.9 242 | torch-spline-conv==1.2.1 243 | torchaudio==0.7.0a0+ac17b64 244 | torchvision==0.8.0 245 | tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work 246 | tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work 247 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 248 | tsBNgen==1.0.0 249 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work 250 | typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work 251 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work 252 | unicodecsv==0.14.1 253 | Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work 254 | urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work 255 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work 256 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 257 | webencodings==0.5.1 258 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work 259 | whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work 260 | widgetsnbextension==3.5.1 261 | wrapt==1.12.1 262 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work 263 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work 264 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work 265 | xlwt==1.3.0 266 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work 267 | zict==2.0.0 268 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work 269 | zope.event==4.5.0 270 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work 271 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 213 | -------------------------------------------------------------------------------- /imports/gdc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numba 3 | import numpy as np 4 | from scipy.linalg import expm 5 | from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj 6 | from torch_sparse import coalesce 7 | from torch_scatter import scatter_add 8 | 9 | 10 | def jit(): 11 | def decorator(func): 12 | try: 13 | return numba.jit(cache=True)(func) 14 | except RuntimeError: 15 | return numba.jit(cache=False)(func) 16 | 17 | return decorator 18 | 19 | 20 | class GDC(object): 21 | r"""Processes the graph via Graph Diffusion Convolution (GDC) from the 22 | `"Diffusion Improves Graph Learning" `_ 23 | paper. 24 | .. note:: 25 | The paper offers additional advice on how to choose the 26 | hyperparameters. 27 | For an example of using GCN with GDC, see `examples/gcn.py 28 | `_. 30 | Args: 31 | self_loop_weight (float, optional): Weight of the added self-loop. 32 | Set to :obj:`None` to add no self-loops. (default: :obj:`1`) 33 | normalization_in (str, optional): Normalization of the transition 34 | matrix on the original (input) graph. 
Possible values: 35 | :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`. 36 | See :func:`GDC.transition_matrix` for details. 37 | (default: :obj:`"sym"`) 38 | normalization_out (str, optional): Normalization of the transition 39 | matrix on the transformed GDC (output) graph. Possible values: 40 | :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`. 41 | See :func:`GDC.transition_matrix` for details. 42 | (default: :obj:`"col"`) 43 | diffusion_kwargs (dict, optional): Dictionary containing the parameters 44 | for diffusion. 45 | `method` specifies the diffusion method (:obj:`"ppr"`, 46 | :obj:`"heat"` or :obj:`"coeff"`). 47 | Each diffusion method requires different additional parameters. 48 | See :func:`GDC.diffusion_matrix_exact` or 49 | :func:`GDC.diffusion_matrix_approx` for details. 50 | (default: :obj:`dict(method='ppr', alpha=0.15)`) 51 | sparsification_kwargs (dict, optional): Dictionary containing the 52 | parameters for sparsification. 53 | `method` specifies the sparsification method (:obj:`"threshold"` or 54 | :obj:`"topk"`). 55 | Each sparsification method requires different additional 56 | parameters. 57 | See :func:`GDC.sparsify_dense` for details. 58 | (default: :obj:`dict(method='threshold', avg_degree=64)`) 59 | exact (bool, optional): Whether to exactly calculate the diffusion 60 | matrix. 61 | Note that the exact variants are not scalable. 62 | They densify the adjacency matrix and calculate either its inverse 63 | or its matrix exponential. 64 | However, the approximate variants do not support edge weights and 65 | currently only personalized PageRank and sparsification by 66 | threshold are implemented as fast, approximate versions. 67 | (default: :obj:`True`) 68 | :rtype: :class:`torch_geometric.data.Data` 69 | """ 70 | def __init__(self, self_loop_weight=1, normalization_in='sym', 71 | normalization_out='col', 72 | diffusion_kwargs=dict(method='ppr', alpha=0.15), 73 | sparsification_kwargs=dict(method='threshold', 74 | avg_degree=64), exact=True): 75 | self.self_loop_weight = self_loop_weight 76 | self.normalization_in = normalization_in 77 | self.normalization_out = normalization_out 78 | self.diffusion_kwargs = diffusion_kwargs 79 | self.sparsification_kwargs = sparsification_kwargs 80 | self.exact = exact 81 | 82 | if self_loop_weight: 83 | assert exact or self_loop_weight == 1 84 | 85 | @torch.no_grad() 86 | def __call__(self, data): 87 | N = data.num_nodes 88 | edge_index = data.edge_index 89 | if data.edge_attr is None: 90 | edge_weight = torch.ones(edge_index.size(1), 91 | device=edge_index.device) 92 | else: 93 | edge_weight = data.edge_attr 94 | assert self.exact 95 | assert edge_weight.dim() == 1 96 | 97 | if self.self_loop_weight: 98 | edge_index, edge_weight = add_self_loops( 99 | edge_index, edge_weight, fill_value=self.self_loop_weight, 100 | num_nodes=N) 101 | 102 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 103 | 104 | if self.exact: 105 | edge_index, edge_weight = self.transition_matrix( 106 | edge_index, edge_weight, N, self.normalization_in) 107 | diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N, 108 | **self.diffusion_kwargs) 109 | edge_index, edge_weight = self.sparsify_dense( 110 | diff_mat, **self.sparsification_kwargs) 111 | else: 112 | edge_index, edge_weight = self.diffusion_matrix_approx( 113 | edge_index, edge_weight, N, self.normalization_in, 114 | **self.diffusion_kwargs) 115 | edge_index, edge_weight = self.sparsify_sparse( 116 | edge_index, edge_weight, N, **self.sparsification_kwargs) 117 | 
118 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 119 | edge_index, edge_weight = self.transition_matrix( 120 | edge_index, edge_weight, N, self.normalization_out) 121 | 122 | data.edge_index = edge_index 123 | data.edge_attr = edge_weight 124 | 125 | return data 126 | 127 | def transition_matrix(self, edge_index, edge_weight, num_nodes, 128 | normalization): 129 | r"""Calculate the approximate, sparse diffusion on a given sparse 130 | matrix. 131 | Args: 132 | edge_index (LongTensor): The edge indices. 133 | edge_weight (Tensor): One-dimensional edge weights. 134 | num_nodes (int): Number of nodes. 135 | normalization (str): Normalization scheme: 136 | 1. :obj:`"sym"`: Symmetric normalization 137 | :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A} 138 | \mathbf{D}^{-1/2}`. 139 | 2. :obj:`"col"`: Column-wise normalization 140 | :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`. 141 | 3. :obj:`"row"`: Row-wise normalization 142 | :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`. 143 | 4. :obj:`None`: No normalization. 144 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 145 | """ 146 | if normalization == 'sym': 147 | row, col = edge_index 148 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 149 | deg_inv_sqrt = deg.pow(-0.5) 150 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 151 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 152 | elif normalization == 'col': 153 | _, col = edge_index 154 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 155 | deg_inv = 1. / deg 156 | deg_inv[deg_inv == float('inf')] = 0 157 | edge_weight = edge_weight * deg_inv[col] 158 | elif normalization == 'row': 159 | row, _ = edge_index 160 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 161 | deg_inv = 1. / deg 162 | deg_inv[deg_inv == float('inf')] = 0 163 | edge_weight = edge_weight * deg_inv[row] 164 | elif normalization is None: 165 | pass 166 | else: 167 | raise ValueError( 168 | 'Transition matrix normalization {} unknown.'.format( 169 | normalization)) 170 | 171 | return edge_index, edge_weight 172 | 173 | def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes, 174 | method, **kwargs): 175 | r"""Calculate the (dense) diffusion on a given sparse graph. 176 | Note that these exact variants are not scalable. They densify the 177 | adjacency matrix and calculate either its inverse or its matrix 178 | exponential. 179 | Args: 180 | edge_index (LongTensor): The edge indices. 181 | edge_weight (Tensor): One-dimensional edge weights. 182 | num_nodes (int): Number of nodes. 183 | method (str): Diffusion method: 184 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. 185 | Additionally expects the parameter: 186 | - **alpha** (*float*) - Return probability in PPR. 187 | Commonly lies in :obj:`[0.05, 0.2]`. 188 | 2. :obj:`"heat"`: Use heat kernel diffusion. 189 | Additionally expects the parameter: 190 | - **t** (*float*) - Time of diffusion. Commonly lies in 191 | :obj:`[2, 10]`. 192 | 3. :obj:`"coeff"`: Freely choose diffusion coefficients. 193 | Additionally expects the parameter: 194 | - **coeffs** (*List[float]*) - List of coefficients 195 | :obj:`theta_k` for each power of the transition matrix 196 | (starting at :obj:`0`). 
197 | :rtype: (:class:`Tensor`) 198 | """ 199 | if method == 'ppr': 200 | # α (I_n + (α - 1) A)^-1 201 | edge_weight = (kwargs['alpha'] - 1) * edge_weight 202 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 203 | fill_value=1, 204 | num_nodes=num_nodes) 205 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() 206 | diff_matrix = kwargs['alpha'] * torch.inverse(mat) 207 | 208 | elif method == 'heat': 209 | # exp(t (A - I_n)) 210 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 211 | fill_value=-1, 212 | num_nodes=num_nodes) 213 | edge_weight = kwargs['t'] * edge_weight 214 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() 215 | undirected = is_undirected(edge_index, edge_weight, num_nodes) 216 | diff_matrix = self.__expm__(mat, undirected) 217 | 218 | elif method == 'coeff': 219 | adj_matrix = to_dense_adj(edge_index, 220 | edge_attr=edge_weight).squeeze() 221 | mat = torch.eye(num_nodes, device=edge_index.device) 222 | 223 | diff_matrix = kwargs['coeffs'][0] * mat 224 | for coeff in kwargs['coeffs'][1:]: 225 | mat = mat @ adj_matrix 226 | diff_matrix += coeff * mat 227 | else: 228 | raise ValueError('Exact GDC diffusion {} unknown.'.format(method)) 229 | 230 | return diff_matrix 231 | 232 | def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes, 233 | normalization, method, **kwargs): 234 | r"""Calculate the approximate, sparse diffusion on a given sparse 235 | graph. 236 | Args: 237 | edge_index (LongTensor): The edge indices. 238 | edge_weight (Tensor): One-dimensional edge weights. 239 | num_nodes (int): Number of nodes. 240 | normalization (str): Transition matrix normalization scheme 241 | (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`). 242 | See :func:`GDC.transition_matrix` for details. 243 | method (str): Diffusion method: 244 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. 245 | Additionally expects the parameters: 246 | - **alpha** (*float*) - Return probability in PPR. 247 | Commonly lies in :obj:`[0.05, 0.2]`. 248 | - **eps** (*float*) - Threshold for PPR calculation stopping 249 | criterion (:obj:`edge_weight >= eps * out_degree`). 250 | Recommended default: :obj:`1e-4`. 251 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 252 | """ 253 | if method == 'ppr': 254 | if normalization == 'sym': 255 | # Calculate original degrees. 256 | _, col = edge_index 257 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 258 | 259 | edge_index_np = edge_index.cpu().numpy() 260 | # Assumes coalesced edge_index. 261 | _, indptr, out_degree = np.unique(edge_index_np[0], 262 | return_index=True, 263 | return_counts=True) 264 | 265 | neighbors, neighbor_weights = GDC.__calc_ppr__( 266 | indptr, edge_index_np[1], out_degree, kwargs['alpha'], 267 | kwargs['eps']) 268 | ppr_normalization = 'col' if normalization == 'col' else 'row' 269 | edge_index, edge_weight = self.__neighbors_to_graph__( 270 | neighbors, neighbor_weights, ppr_normalization, 271 | device=edge_index.device) 272 | edge_index = edge_index.to(torch.long) 273 | 274 | if normalization == 'sym': 275 | # We can change the normalization from row-normalized to 276 | # symmetric by multiplying the resulting matrix with D^{1/2} 277 | # from the left and D^{-1/2} from the right. 278 | # Since we use the original degrees for this it will be like 279 | # we had used symmetric normalization from the beginning 280 | # (except for errors due to approximation). 
281 | row, col = edge_index 282 | deg_inv = deg.sqrt() 283 | deg_inv_sqrt = deg.pow(-0.5) 284 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 285 | edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col] 286 | elif normalization in ['col', 'row']: 287 | pass 288 | else: 289 | raise ValueError( 290 | ('Transition matrix normalization {} not implemented for ' 291 | 'non-exact GDC computation.').format(normalization)) 292 | 293 | elif method == 'heat': 294 | raise NotImplementedError( 295 | ('Currently no fast heat kernel is implemented. You are ' 296 | 'welcome to create one yourself, e.g., based on ' 297 | '"Kloster and Gleich: Heat kernel based community detection ' 298 | '(KDD 2014)."')) 299 | else: 300 | raise ValueError( 301 | 'Approximate GDC diffusion {} unknown.'.format(method)) 302 | 303 | return edge_index, edge_weight 304 | 305 | def sparsify_dense(self, matrix, method, **kwargs): 306 | r"""Sparsifies the given dense matrix. 307 | Args: 308 | matrix (Tensor): Matrix to sparsify. 309 | num_nodes (int): Number of nodes. 310 | method (str): Method of sparsification. Options: 311 | 1. :obj:`"threshold"`: Remove all edges with weights smaller 312 | than :obj:`eps`. 313 | Additionally expects one of these parameters: 314 | - **eps** (*float*) - Threshold to bound edges at. 315 | - **avg_degree** (*int*) - If :obj:`eps` is not given, 316 | it can optionally be calculated by calculating the 317 | :obj:`eps` required to achieve a given :obj:`avg_degree`. 318 | 2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per 319 | node (column). 320 | Additionally expects the following parameters: 321 | - **k** (*int*) - Specifies the number of edges to keep. 322 | - **dim** (*int*) - The axis along which to take the top 323 | :obj:`k`. 324 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 325 | """ 326 | assert matrix.shape[0] == matrix.shape[1] 327 | N = matrix.shape[1] 328 | 329 | if method == 'threshold': 330 | if 'eps' not in kwargs.keys(): 331 | kwargs['eps'] = self.__calculate_eps__(matrix, N, 332 | kwargs['avg_degree']) 333 | 334 | edge_index = torch.nonzero(matrix >= kwargs['eps']).t() 335 | edge_index_flat = edge_index[0] * N + edge_index[1] 336 | edge_weight = matrix.flatten()[edge_index_flat] 337 | 338 | elif method == 'topk': 339 | assert kwargs['dim'] in [0, 1] 340 | sort_idx = torch.argsort(matrix, dim=kwargs['dim'], 341 | descending=True) 342 | if kwargs['dim'] == 0: 343 | top_idx = sort_idx[:kwargs['k']] 344 | edge_weight = torch.gather(matrix, dim=kwargs['dim'], 345 | index=top_idx).flatten() 346 | 347 | row_idx = torch.arange(0, N, device=matrix.device).repeat( 348 | kwargs['k']) 349 | edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0) 350 | else: 351 | top_idx = sort_idx[:, :kwargs['k']] 352 | edge_weight = torch.gather(matrix, dim=kwargs['dim'], 353 | index=top_idx).flatten() 354 | 355 | col_idx = torch.arange( 356 | 0, N, device=matrix.device).repeat_interleave(kwargs['k']) 357 | edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0) 358 | else: 359 | raise ValueError('GDC sparsification {} unknown.'.format(method)) 360 | 361 | return edge_index, edge_weight 362 | 363 | def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method, 364 | **kwargs): 365 | r"""Sparsifies a given sparse graph further. 366 | Args: 367 | edge_index (LongTensor): The edge indices. 368 | edge_weight (Tensor): One-dimensional edge weights. 369 | num_nodes (int): Number of nodes. 370 | method (str): Method of sparsification: 371 | 1. 
:obj:`"threshold"`: Remove all edges with weights smaller 372 | than :obj:`eps`. 373 | Additionally expects one of these parameters: 374 | - **eps** (*float*) - Threshold to bound edges at. 375 | - **avg_degree** (*int*) - If :obj:`eps` is not given, 376 | it can optionally be calculated by calculating the 377 | :obj:`eps` required to achieve a given :obj:`avg_degree`. 378 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 379 | """ 380 | if method == 'threshold': 381 | if 'eps' not in kwargs.keys(): 382 | kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes, 383 | kwargs['avg_degree']) 384 | 385 | remaining_edge_idx = torch.nonzero( 386 | edge_weight >= kwargs['eps']).flatten() 387 | edge_index = edge_index[:, remaining_edge_idx] 388 | edge_weight = edge_weight[remaining_edge_idx] 389 | elif method == 'topk': 390 | raise NotImplementedError( 391 | 'Sparse topk sparsification not implemented.') 392 | else: 393 | raise ValueError('GDC sparsification {} unknown.'.format(method)) 394 | 395 | return edge_index, edge_weight 396 | 397 | def __expm__(self, matrix, symmetric): 398 | r"""Calculates matrix exponential. 399 | Args: 400 | matrix (Tensor): Matrix to take exponential of. 401 | symmetric (bool): Specifies whether the matrix is symmetric. 402 | :rtype: (:class:`Tensor`) 403 | """ 404 | if symmetric: 405 | e, V = torch.symeig(matrix, eigenvectors=True) 406 | diff_mat = V @ torch.diag(e.exp()) @ V.t() 407 | else: 408 | diff_mat_np = expm(matrix.cpu().numpy()) 409 | diff_mat = torch.Tensor(diff_mat_np).to(matrix.device) 410 | return diff_mat 411 | 412 | def __calculate_eps__(self, matrix, num_nodes, avg_degree): 413 | r"""Calculates threshold necessary to achieve a given average degree. 414 | Args: 415 | matrix (Tensor): Adjacency matrix or edge weights. 416 | num_nodes (int): Number of nodes. 417 | avg_degree (int): Target average degree. 418 | :rtype: (:class:`float`) 419 | """ 420 | sorted_edges = torch.sort(matrix.flatten(), descending=True).values 421 | if avg_degree * num_nodes > len(sorted_edges): 422 | return -np.inf 423 | return sorted_edges[avg_degree * num_nodes - 1] 424 | 425 | def __neighbors_to_graph__(self, neighbors, neighbor_weights, 426 | normalization='row', device='cpu'): 427 | r"""Combine a list of neighbors and neighbor weights to create a sparse 428 | graph. 429 | Args: 430 | neighbors (List[List[int]]): List of neighbors for each node. 431 | neighbor_weights (List[List[float]]): List of weights for the 432 | neighbors of each node. 433 | normalization (str): Normalization of resulting matrix 434 | (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`) 435 | device (torch.device): Device to create output tensors on. 
436 | (default: :obj:`"cpu"`) 437 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 438 | """ 439 | edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device) 440 | i = np.repeat(np.arange(len(neighbors)), 441 | np.fromiter(map(len, neighbors), dtype=np.int)) 442 | j = np.concatenate(neighbors) 443 | if normalization == 'col': 444 | edge_index = torch.Tensor(np.vstack([j, i])).to(device) 445 | N = len(neighbors) 446 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 447 | elif normalization == 'row': 448 | edge_index = torch.Tensor(np.vstack([i, j])).to(device) 449 | else: 450 | raise ValueError( 451 | f"PPR matrix normalization {normalization} unknown.") 452 | return edge_index, edge_weight 453 | 454 | @staticmethod 455 | @jit() 456 | def __calc_ppr__(indptr, indices, out_degree, alpha, eps): 457 | r"""Calculate the personalized PageRank vector for all nodes 458 | using a variant of the Andersen algorithm 459 | (see Andersen et al. :Local Graph Partitioning using PageRank Vectors.) 460 | Args: 461 | indptr (np.ndarray): Index pointer for the sparse matrix 462 | (CSR-format). 463 | indices (np.ndarray): Indices of the sparse matrix entries 464 | (CSR-format). 465 | out_degree (np.ndarray): Out-degree of each node. 466 | alpha (float): Alpha of the PageRank to calculate. 467 | eps (float): Threshold for PPR calculation stopping criterion 468 | (:obj:`edge_weight >= eps * out_degree`). 469 | :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`) 470 | """ 471 | alpha_eps = alpha * eps 472 | js = [] 473 | vals = [] 474 | for inode in range(len(out_degree)): 475 | p = {inode: 0.0} 476 | r = {} 477 | r[inode] = alpha 478 | q = [inode] 479 | while len(q) > 0: 480 | unode = q.pop() 481 | 482 | res = r[unode] if unode in r else 0 483 | if unode in p: 484 | p[unode] += res 485 | else: 486 | p[unode] = res 487 | r[unode] = 0 488 | for vnode in indices[indptr[unode]:indptr[unode + 1]]: 489 | _val = (1 - alpha) * res / out_degree[unode] 490 | if vnode in r: 491 | r[vnode] += _val 492 | else: 493 | r[vnode] = _val 494 | 495 | res_vnode = r[vnode] if vnode in r else 0 496 | if res_vnode >= alpha_eps * out_degree[vnode]: 497 | if vnode not in q: 498 | q.append(vnode) 499 | js.append(list(p.keys())) 500 | vals.append(list(p.values())) 501 | return js, vals 502 | 503 | def __repr__(self): 504 | return '{}()'.format(self.__class__.__name__) --------------------------------------------------------------------------------
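
For reference, below is a minimal sketch of how the `GDC` transform defined in `imports/gdc.py` could be applied to a single `torch_geometric` graph. It only exercises the transform in isolation; the toy 4-node graph, the `avg_degree=2` setting, and the assumption that the repository root is importable (the `imports` package ships `__inits__.py` rather than `__init__.py`, so the import path may need adjusting) are illustrative assumptions, not part of the repository or of how `02-process_data.py` necessarily invokes it.

```
# Minimal usage sketch for imports/gdc.py (illustrative only; the toy graph
# and hyperparameters below are assumptions, not values used by BrainGNN).
import torch
from torch_geometric.data import Data

from imports.gdc import GDC  # assumes the repository root is on PYTHONPATH

# Toy undirected graph on 4 nodes; both directions of each edge are listed.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]], dtype=torch.long)
data = Data(x=torch.eye(4), edge_index=edge_index)  # one-hot placeholder features

# Exact personalized-PageRank diffusion with symmetric input normalization,
# then thresholding the dense diffusion matrix to an average degree of 2.
gdc = GDC(self_loop_weight=1,
          normalization_in='sym',
          normalization_out='col',
          diffusion_kwargs=dict(method='ppr', alpha=0.15),
          sparsification_kwargs=dict(method='threshold', avg_degree=2),
          exact=True)

data = gdc(data)
print(data.edge_index.size(), data.edge_attr.size())  # diffused, sparsified graph
```

With `exact=True` the transform densifies the adjacency matrix and inverts it, which is fine for brain graphs with on the order of a hundred ROIs but does not scale to large graphs; the approximate PPR path (`exact=False`) avoids this but, as the class docstring notes, does not support edge weights.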