├── .idea ├── .gitignore ├── GNN_biomarker_MEDIA.iml ├── deployment.xml ├── encodings.xml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml └── webServers.xml ├── 01-fetch_data.py ├── 02-process_data.py ├── 03-main.py ├── README.md ├── imports ├── ABIDEDataset.py ├── __inits__.py ├── __pycache__ │ ├── preprocess_data.cpython-36.pyc │ └── preprocess_data.cpython-38.pyc ├── gdc.py ├── preprocess_data.py ├── read_abide_stats_parall.py └── utils.py ├── net ├── braingnn.py ├── braingraphconv.py ├── brainmsgpassing.py └── inits.py └── requirements.txt /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/GNN_biomarker_MEDIA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 213 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | -------------------------------------------------------------------------------- /01-fetch_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot , , Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see https://www.gnu.org/licenses/.
16 | 17 | from nilearn import datasets 18 | import argparse 19 | from imports import preprocess_data as Reader 20 | import os 21 | import shutil 22 | import sys 23 | 24 | # Input data variables 25 | code_folder = os.getcwd() 26 | root_folder = '/home/azureuser/projects/BrainGNN/data/' 27 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/') 28 | if not os.path.exists(data_folder): 29 | os.makedirs(data_folder) 30 | shutil.copyfile(os.path.join(root_folder,'subject_ID.txt'), os.path.join(data_folder, 'subject_IDs.txt')) 31 | 32 | def str2bool(v): 33 | if isinstance(v, bool): 34 | return v 35 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 36 | return True 37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 38 | return False 39 | else: 40 | raise argparse.ArgumentTypeError('Boolean value expected.') 41 | 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices') 45 | parser.add_argument('--pipeline', default='cpac', type=str, 46 | help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak.' 47 | ' default: cpac.') 48 | parser.add_argument('--atlas', default='cc200', 49 | help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.') 50 | parser.add_argument('--download', default=True, type=str2bool, 51 | help='Dowload data or just compute functional connectivity. 
default: True') 52 | args = parser.parse_args() 53 | print(args) 54 | 55 | params = dict() 56 | 57 | pipeline = args.pipeline 58 | atlas = args.atlas 59 | download = args.download 60 | 61 | # Files to fetch 62 | 63 | files = ['rois_' + atlas] 64 | 65 | filemapping = {'func_preproc': 'func_preproc.nii.gz', 66 | files[0]: files[0] + '.1D'} 67 | 68 | 69 | # Download database files 70 | if download == True: 71 | abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline, 72 | band_pass_filtering=True, global_signal_regression=False, derivatives=files, 73 | quality_checked=False) 74 | 75 | subject_IDs = Reader.get_ids() #changed path to data path 76 | subject_IDs = subject_IDs.tolist() 77 | 78 | # Create a folder for each subject 79 | for s, fname in zip(subject_IDs, Reader.fetch_filenames(subject_IDs, files[0], atlas)): 80 | subject_folder = os.path.join(data_folder, s) 81 | if not os.path.exists(subject_folder): 82 | os.mkdir(subject_folder) 83 | 84 | # Get the base filename for each subject 85 | base = fname.split(files[0])[0] 86 | 87 | # Move each subject file to the subject folder 88 | for fl in files: 89 | if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])): 90 | shutil.move(base + filemapping[fl], subject_folder) 91 | 92 | time_series = Reader.get_timeseries(subject_IDs, atlas) 93 | 94 | # Compute and save connectivity matrices 95 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation') 96 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation') 97 | 98 | 99 | if __name__ == '__main__': 100 | main() -------------------------------------------------------------------------------- /02-process_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # 
the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | 16 | 17 | import sys 18 | import argparse 19 | import pandas as pd 20 | import numpy as np 21 | from imports import preprocess_data as Reader 22 | import deepdish as dd 23 | import warnings 24 | import os 25 | 26 | warnings.filterwarnings("ignore") 27 | root_folder = '/home/azureuser/projects/BrainGNN/data/' 28 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/') 29 | 30 | # Process boolean command line arguments 31 | def str2bool(v): 32 | if isinstance(v, bool): 33 | return v 34 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 35 | return True 36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 37 | return False 38 | else: 39 | raise argparse.ArgumentTypeError('Boolean value expected.') 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser(description='Classification of the ABIDE dataset using a Ridge classifier. ' 44 | 'MIDA is used to minimize the distribution mismatch between ABIDE sites') 45 | parser.add_argument('--atlas', default='cc200', 46 | help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.') 47 | parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 1234.') 48 | parser.add_argument('--nclass', default=2, type=int, help='Number of classes. 
default:2') 49 | 50 | 51 | args = parser.parse_args() 52 | print('Arguments: \n', args) 53 | 54 | 55 | params = dict() 56 | 57 | params['seed'] = args.seed # seed for random initialisation 58 | 59 | # Algorithm choice 60 | params['atlas'] = args.atlas # Atlas for network construction 61 | atlas = args.atlas # Atlas for network construction (node definition) 62 | 63 | # Get subject IDs and class labels 64 | subject_IDs = Reader.get_ids() 65 | labels = Reader.get_subject_score(subject_IDs, score='DX_GROUP') 66 | 67 | # Number of subjects and classes for binary classification 68 | num_classes = args.nclass 69 | num_subjects = len(subject_IDs) 70 | params['n_subjects'] = num_subjects 71 | 72 | # Initialise variables for class labels and acquisition sites 73 | # 1 is autism, 2 is control 74 | y_data = np.zeros([num_subjects, num_classes]) # n x 2 75 | y = np.zeros([num_subjects, 1]) # n x 1 76 | 77 | # Get class labels for all subjects 78 | for i in range(num_subjects): 79 | y_data[i, int(labels[subject_IDs[i]]) - 1] = 1 80 | y[i] = int(labels[subject_IDs[i]]) 81 | 82 | # Compute feature vectors (vectorised connectivity networks) 83 | fea_corr = Reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas) #(1035, 200, 200) 84 | fea_pcorr = Reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas) #(1035, 200, 200) 85 | 86 | if not os.path.exists(os.path.join(data_folder,'raw')): 87 | os.makedirs(os.path.join(data_folder,'raw')) 88 | for i, subject in enumerate(subject_IDs): 89 | dd.io.save(os.path.join(data_folder,'raw',subject+'.h5'),{'corr':fea_corr[i],'pcorr':fea_pcorr[i],'label':y[i]%2}) 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /03-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import time 5 | import copy 6 | 7 | 
import torch 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | from tensorboardX import SummaryWriter 11 | 12 | from imports.ABIDEDataset import ABIDEDataset 13 | from torch_geometric.data import DataLoader 14 | from net.braingnn import Network 15 | from imports.utils import train_val_test_split 16 | from sklearn.metrics import classification_report, confusion_matrix 17 | 18 | torch.manual_seed(123) 19 | 20 | EPS = 1e-10 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--epoch', type=int, default=0, help='starting epoch') 26 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training') 27 | parser.add_argument('--batchSize', type=int, default=100, help='size of the batches') 28 | parser.add_argument('--dataroot', type=str, default='/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal', help='root directory of the dataset') 29 | parser.add_argument('--fold', type=int, default=0, help='training which fold') 30 | parser.add_argument('--lr', type = float, default=0.01, help='learning rate') 31 | parser.add_argument('--stepsize', type=int, default=20, help='scheduler step size') 32 | parser.add_argument('--gamma', type=float, default=0.5, help='scheduler shrinking rate') 33 | parser.add_argument('--weightdecay', type=float, default=5e-3, help='regularization') 34 | parser.add_argument('--lamb0', type=float, default=1, help='classification loss weight') 35 | parser.add_argument('--lamb1', type=float, default=0, help='s1 unit regularization') 36 | parser.add_argument('--lamb2', type=float, default=0, help='s2 unit regularization') 37 | parser.add_argument('--lamb3', type=float, default=0.1, help='s1 entropy regularization') 38 | parser.add_argument('--lamb4', type=float, default=0.1, help='s2 entropy regularization') 39 | parser.add_argument('--lamb5', type=float, default=0, help='s1 
consistence regularization') 40 | parser.add_argument('--layer', type=int, default=2, help='number of GNN layers') 41 | parser.add_argument('--ratio', type=float, default=0.5, help='pooling ratio') 42 | parser.add_argument('--indim', type=int, default=200, help='feature dim') 43 | parser.add_argument('--nroi', type=int, default=200, help='num of ROIs') 44 | parser.add_argument('--nclass', type=int, default=2, help='num of classes') 45 | parser.add_argument('--load_model', type=bool, default=False) 46 | parser.add_argument('--save_model', type=bool, default=True) 47 | parser.add_argument('--optim', type=str, default='SGD', help='optimization method: SGD, Adam') 48 | parser.add_argument('--save_path', type=str, default='./model/', help='path to save model') 49 | opt = parser.parse_args() 50 | 51 | if not os.path.exists(opt.save_path): 52 | os.makedirs(opt.save_path) 53 | 54 | #################### Parameter Initialization ####################### 55 | path = opt.dataroot 56 | name = 'ABIDE' 57 | save_model = opt.save_model 58 | load_model = opt.load_model 59 | opt_method = opt.optim 60 | num_epoch = opt.n_epochs 61 | fold = opt.fold 62 | writer = SummaryWriter(os.path.join('./log',str(fold))) 63 | 64 | 65 | 66 | ################## Define Dataloader ################################## 67 | 68 | dataset = ABIDEDataset(path,name) 69 | dataset.data.y = dataset.data.y.squeeze() 70 | dataset.data.x[dataset.data.x == float('inf')] = 0 71 | 72 | tr_index,val_index,te_index = train_val_test_split(fold=fold) 73 | train_mask = torch.zeros(len(dataset), dtype=torch.uint8) 74 | val_mask = torch.zeros(len(dataset), dtype=torch.uint8) 75 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 76 | train_mask[tr_index] = 1 77 | val_mask[val_index] = 1 78 | test_mask[te_index] = 1 79 | train_dataset = dataset[train_mask] 80 | val_dataset = dataset[val_mask] 81 | test_dataset = dataset[test_mask] 82 | 83 | 84 | train_loader = DataLoader(train_dataset,batch_size=opt.batchSize, shuffle= 
True) 85 | val_loader = DataLoader(val_dataset, batch_size=opt.batchSize, shuffle=False) 86 | test_loader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False) 87 | 88 | 89 | 90 | ############### Define Graph Deep Learning Network ########################## 91 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device) 92 | print(model) 93 | 94 | if opt_method == 'Adam': 95 | optimizer = torch.optim.Adam(model.parameters(), lr= opt.lr, weight_decay=opt.weightdecay) 96 | elif opt_method == 'SGD': 97 | optimizer = torch.optim.SGD(model.parameters(), lr =opt.lr, momentum = 0.9, weight_decay=opt.weightdecay, nesterov = True) 98 | 99 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.stepsize, gamma=opt.gamma) 100 | 101 | ############################### Define Other Loss Functions ######################################## 102 | def topk_loss(s,ratio): 103 | if ratio > 0.5: 104 | ratio = 1-ratio 105 | s = s.sort(dim=1).values 106 | res = -torch.log(s[:,-int(s.size(1)*ratio):]+EPS).mean() -torch.log(1-s[:,:int(s.size(1)*ratio)]+EPS).mean() 107 | return res 108 | 109 | 110 | def consist_loss(s): 111 | if len(s) == 0: 112 | return 0 113 | s = torch.sigmoid(s) 114 | W = torch.ones(s.shape[0],s.shape[0]) 115 | D = torch.eye(s.shape[0])*torch.sum(W,dim=1) 116 | L = D-W 117 | L = L.to(device) 118 | res = torch.trace(torch.transpose(s,0,1) @ L @ s)/(s.shape[0]*s.shape[0]) 119 | return res 120 | 121 | ###################### Network Training Function##################################### 122 | def train(epoch): 123 | print('train...........') 124 | scheduler.step() 125 | 126 | for param_group in optimizer.param_groups: 127 | print("LR", param_group['lr']) 128 | model.train() 129 | s1_list = [] 130 | s2_list = [] 131 | loss_all = 0 132 | step = 0 133 | for data in train_loader: 134 | data = data.to(device) 135 | optimizer.zero_grad() 136 | output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos) 137 | 
s1_list.append(s1.view(-1).detach().cpu().numpy()) 138 | s2_list.append(s2.view(-1).detach().cpu().numpy()) 139 | 140 | loss_c = F.nll_loss(output, data.y) 141 | 142 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2 143 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2 144 | loss_tpk1 = topk_loss(s1,opt.ratio) 145 | loss_tpk2 = topk_loss(s2,opt.ratio) 146 | loss_consist = 0 147 | for c in range(opt.nclass): 148 | loss_consist += consist_loss(s1[data.y == c]) 149 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \ 150 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist 151 | writer.add_scalar('train/classification_loss', loss_c, epoch*len(train_loader)+step) 152 | writer.add_scalar('train/unit_loss1', loss_p1, epoch*len(train_loader)+step) 153 | writer.add_scalar('train/unit_loss2', loss_p2, epoch*len(train_loader)+step) 154 | writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch*len(train_loader)+step) 155 | writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch*len(train_loader)+step) 156 | writer.add_scalar('train/GCL_loss', loss_consist, epoch*len(train_loader)+step) 157 | step = step + 1 158 | 159 | loss.backward() 160 | loss_all += loss.item() * data.num_graphs 161 | optimizer.step() 162 | 163 | s1_arr = np.hstack(s1_list) 164 | s2_arr = np.hstack(s2_list) 165 | return loss_all / len(train_dataset), s1_arr, s2_arr ,w1,w2 166 | 167 | 168 | ###################### Network Testing Function##################################### 169 | def test_acc(loader): 170 | model.eval() 171 | correct = 0 172 | for data in loader: 173 | data = data.to(device) 174 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 175 | pred = outputs[0].max(dim=1)[1] 176 | correct += pred.eq(data.y).sum().item() 177 | 178 | return correct / len(loader.dataset) 179 | 180 | def test_loss(loader,epoch): 181 | print('testing...........') 182 | model.eval() 183 | loss_all = 0 184 | for data in loader: 185 | data = data.to(device) 186 | output, w1, 
w2, s1, s2= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 187 | loss_c = F.nll_loss(output, data.y) 188 | 189 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2 190 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2 191 | loss_tpk1 = topk_loss(s1,opt.ratio) 192 | loss_tpk2 = topk_loss(s2,opt.ratio) 193 | loss_consist = 0 194 | for c in range(opt.nclass): 195 | loss_consist += consist_loss(s1[data.y == c]) 196 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \ 197 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist 198 | 199 | loss_all += loss.item() * data.num_graphs 200 | return loss_all / len(loader.dataset) 201 | 202 | ####################################################################################### 203 | ############################ Model Training ######################################### 204 | ####################################################################################### 205 | best_model_wts = copy.deepcopy(model.state_dict()) 206 | best_loss = 1e10 207 | for epoch in range(0, num_epoch): 208 | since = time.time() 209 | tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch) 210 | tr_acc = test_acc(train_loader) 211 | val_acc = test_acc(val_loader) 212 | val_loss = test_loss(val_loader,epoch) 213 | time_elapsed = time.time() - since 214 | print('*====**') 215 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 216 | print('Epoch: {:03d}, Train Loss: {:.7f}, ' 217 | 'Train Acc: {:.7f}, Test Loss: {:.7f}, Test Acc: {:.7f}'.format(epoch, tr_loss, 218 | tr_acc, val_loss, val_acc)) 219 | 220 | writer.add_scalars('Acc',{'train_acc':tr_acc,'val_acc':val_acc}, epoch) 221 | writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss}, epoch) 222 | writer.add_histogram('Hist/hist_s1', s1_arr, epoch) 223 | writer.add_histogram('Hist/hist_s2', s2_arr, epoch) 224 | 225 | if val_loss < best_loss and epoch > 5: 226 | print("saving best model") 227 | best_loss = val_loss 228 | best_model_wts 
= copy.deepcopy(model.state_dict()) 229 | if save_model: 230 | torch.save(best_model_wts, os.path.join(opt.save_path,str(fold)+'.pth')) 231 | 232 | ####################################################################################### 233 | ######################### Testing on testing set ###################################### 234 | ####################################################################################### 235 | 236 | if opt.load_model: 237 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device) 238 | model.load_state_dict(torch.load(os.path.join(opt.save_path,str(fold)+'.pth'))) 239 | model.eval() 240 | preds = [] 241 | correct = 0 242 | for data in val_loader: 243 | data = data.to(device) 244 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos) 245 | pred = outputs[0].max(1)[1] 246 | preds.append(pred.cpu().detach().numpy()) 247 | correct += pred.eq(data.y).sum().item() 248 | preds = np.concatenate(preds,axis=0) 249 | trues = val_dataset.data.y.cpu().detach().numpy() 250 | cm = confusion_matrix(trues,preds) 251 | print("Confusion matrix") 252 | print(classification_report(trues, preds)) 253 | 254 | else: 255 | model.load_state_dict(best_model_wts) 256 | model.eval() 257 | test_accuracy = test_acc(test_loader) 258 | test_l= test_loss(test_loader,0) 259 | print("===========================") 260 | print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l)) 261 | print(opt) 262 | 263 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Neural Network for Brain Network Analysis 2 | A preliminary implementation of BrainGNN 3 | 4 | 5 | ## Usage 6 | ### Setup 7 | **pip** 8 | 9 | See the `requirements.txt` for environment configuration. 
10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | **PYG** 14 | 15 | To install pyg library, [please refer to the document](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 16 | 17 | ### Dataset 18 | **ABIDE** 19 | 20 | We treat each fMRI as a brain graph. How to download and construct the graphs? 21 | ``` 22 | python 01-fetch_data.py 23 | python 02-process_data.py 24 | ``` 25 | 26 | ### How to run classification? 27 | Training and testing are integrated in file `main.py`. To run 28 | ``` 29 | python 03-main.py 30 | ``` 31 | 32 | 33 | ## Citation 34 | If you find the code and dataset useful, please cite our paper. 35 | ```latex 36 | @article{li2020braingnn, 37 | title={Braingnn: Interpretable brain graph neural network for fmri analysis}, 38 | author={Li, Xiaoxiao and Zhou,Yuan and Dvornek, Nicha and Zhang, Muhan and Gao, Siyuan and Zhuang, Juntang and Scheinost, Dustin and Staib, Lawrence and Ventola, Pamela and Duncan, James}, 39 | journal={bioRxiv}, 40 | year={2020}, 41 | publisher={Cold Spring Harbor Laboratory} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /imports/ABIDEDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset,Data 3 | from os.path import join, isfile 4 | from os import listdir 5 | import numpy as np 6 | import os.path as osp 7 | from imports.read_abide_stats_parall import read_data 8 | 9 | 10 | class ABIDEDataset(InMemoryDataset): 11 | def __init__(self, root, name, transform=None, pre_transform=None): 12 | self.root = root 13 | self.name = name 14 | super(ABIDEDataset, self).__init__(root,transform, pre_transform) 15 | self.data, self.slices = torch.load(self.processed_paths[0]) 16 | 17 | @property 18 | def raw_file_names(self): 19 | data_dir = osp.join(self.root,'raw') 20 | onlyfiles = [f for f in listdir(data_dir) if 
osp.isfile(osp.join(data_dir, f))] 21 | onlyfiles.sort() 22 | return onlyfiles 23 | @property 24 | def processed_file_names(self): 25 | return 'data.pt' 26 | 27 | def download(self): 28 | # Download to `self.raw_dir`. 29 | return 30 | 31 | def process(self): 32 | # Read data into huge `Data` list. 33 | self.data, self.slices = read_data(self.raw_dir) 34 | 35 | if self.pre_filter is not None: 36 | data_list = [self.get(idx) for idx in range(len(self))] 37 | data_list = [data for data in data_list if self.pre_filter(data)] 38 | self.data, self.slices = self.collate(data_list) 39 | 40 | if self.pre_transform is not None: 41 | data_list = [self.get(idx) for idx in range(len(self))] 42 | data_list = [self.pre_transform(data) for data in data_list] 43 | self.data, self.slices = self.collate(data_list) 44 | 45 | torch.save((self.data, self.slices), self.processed_paths[0]) 46 | 47 | def __repr__(self): 48 | return '{}({})'.format(self.name, len(self)) 49 | -------------------------------------------------------------------------------- /imports/__inits__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__inits__.py -------------------------------------------------------------------------------- /imports/__pycache__/preprocess_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__pycache__/preprocess_data.cpython-36.pyc -------------------------------------------------------------------------------- /imports/__pycache__/preprocess_data.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__pycache__/preprocess_data.cpython-38.pyc -------------------------------------------------------------------------------- /imports/gdc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numba 3 | import numpy as np 4 | from scipy.linalg import expm 5 | from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj 6 | from torch_sparse import coalesce 7 | from torch_scatter import scatter_add 8 | 9 | 10 | def jit(): 11 | def decorator(func): 12 | try: 13 | return numba.jit(cache=True)(func) 14 | except RuntimeError: 15 | return numba.jit(cache=False)(func) 16 | 17 | return decorator 18 | 19 | 20 | class GDC(object): 21 | r"""Processes the graph via Graph Diffusion Convolution (GDC) from the 22 | `"Diffusion Improves Graph Learning" `_ 23 | paper. 24 | .. note:: 25 | The paper offers additional advice on how to choose the 26 | hyperparameters. 27 | For an example of using GCN with GDC, see `examples/gcn.py 28 | `_. 30 | Args: 31 | self_loop_weight (float, optional): Weight of the added self-loop. 32 | Set to :obj:`None` to add no self-loops. (default: :obj:`1`) 33 | normalization_in (str, optional): Normalization of the transition 34 | matrix on the original (input) graph. Possible values: 35 | :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`. 36 | See :func:`GDC.transition_matrix` for details. 37 | (default: :obj:`"sym"`) 38 | normalization_out (str, optional): Normalization of the transition 39 | matrix on the transformed GDC (output) graph. Possible values: 40 | :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`. 41 | See :func:`GDC.transition_matrix` for details. 42 | (default: :obj:`"col"`) 43 | diffusion_kwargs (dict, optional): Dictionary containing the parameters 44 | for diffusion. 
45 | `method` specifies the diffusion method (:obj:`"ppr"`, 46 | :obj:`"heat"` or :obj:`"coeff"`). 47 | Each diffusion method requires different additional parameters. 48 | See :func:`GDC.diffusion_matrix_exact` or 49 | :func:`GDC.diffusion_matrix_approx` for details. 50 | (default: :obj:`dict(method='ppr', alpha=0.15)`) 51 | sparsification_kwargs (dict, optional): Dictionary containing the 52 | parameters for sparsification. 53 | `method` specifies the sparsification method (:obj:`"threshold"` or 54 | :obj:`"topk"`). 55 | Each sparsification method requires different additional 56 | parameters. 57 | See :func:`GDC.sparsify_dense` for details. 58 | (default: :obj:`dict(method='threshold', avg_degree=64)`) 59 | exact (bool, optional): Whether to exactly calculate the diffusion 60 | matrix. 61 | Note that the exact variants are not scalable. 62 | They densify the adjacency matrix and calculate either its inverse 63 | or its matrix exponential. 64 | However, the approximate variants do not support edge weights and 65 | currently only personalized PageRank and sparsification by 66 | threshold are implemented as fast, approximate versions. 
67 | (default: :obj:`True`) 68 | :rtype: :class:`torch_geometric.data.Data` 69 | """ 70 | def __init__(self, self_loop_weight=1, normalization_in='sym', 71 | normalization_out='col', 72 | diffusion_kwargs=dict(method='ppr', alpha=0.15), 73 | sparsification_kwargs=dict(method='threshold', 74 | avg_degree=64), exact=True): 75 | self.self_loop_weight = self_loop_weight 76 | self.normalization_in = normalization_in 77 | self.normalization_out = normalization_out 78 | self.diffusion_kwargs = diffusion_kwargs 79 | self.sparsification_kwargs = sparsification_kwargs 80 | self.exact = exact 81 | 82 | if self_loop_weight: 83 | assert exact or self_loop_weight == 1 84 | 85 | @torch.no_grad() 86 | def __call__(self, data): 87 | N = data.num_nodes 88 | edge_index = data.edge_index 89 | if data.edge_attr is None: 90 | edge_weight = torch.ones(edge_index.size(1), 91 | device=edge_index.device) 92 | else: 93 | edge_weight = data.edge_attr 94 | assert self.exact 95 | assert edge_weight.dim() == 1 96 | 97 | if self.self_loop_weight: 98 | edge_index, edge_weight = add_self_loops( 99 | edge_index, edge_weight, fill_value=self.self_loop_weight, 100 | num_nodes=N) 101 | 102 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 103 | 104 | if self.exact: 105 | edge_index, edge_weight = self.transition_matrix( 106 | edge_index, edge_weight, N, self.normalization_in) 107 | diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N, 108 | **self.diffusion_kwargs) 109 | edge_index, edge_weight = self.sparsify_dense( 110 | diff_mat, **self.sparsification_kwargs) 111 | else: 112 | edge_index, edge_weight = self.diffusion_matrix_approx( 113 | edge_index, edge_weight, N, self.normalization_in, 114 | **self.diffusion_kwargs) 115 | edge_index, edge_weight = self.sparsify_sparse( 116 | edge_index, edge_weight, N, **self.sparsification_kwargs) 117 | 118 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 119 | edge_index, edge_weight = self.transition_matrix( 120 
| edge_index, edge_weight, N, self.normalization_out) 121 | 122 | data.edge_index = edge_index 123 | data.edge_attr = edge_weight 124 | 125 | return data 126 | 127 | def transition_matrix(self, edge_index, edge_weight, num_nodes, 128 | normalization): 129 | r"""Calculate the approximate, sparse diffusion on a given sparse 130 | matrix. 131 | Args: 132 | edge_index (LongTensor): The edge indices. 133 | edge_weight (Tensor): One-dimensional edge weights. 134 | num_nodes (int): Number of nodes. 135 | normalization (str): Normalization scheme: 136 | 1. :obj:`"sym"`: Symmetric normalization 137 | :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A} 138 | \mathbf{D}^{-1/2}`. 139 | 2. :obj:`"col"`: Column-wise normalization 140 | :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`. 141 | 3. :obj:`"row"`: Row-wise normalization 142 | :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`. 143 | 4. :obj:`None`: No normalization. 144 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 145 | """ 146 | if normalization == 'sym': 147 | row, col = edge_index 148 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 149 | deg_inv_sqrt = deg.pow(-0.5) 150 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 151 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 152 | elif normalization == 'col': 153 | _, col = edge_index 154 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 155 | deg_inv = 1. / deg 156 | deg_inv[deg_inv == float('inf')] = 0 157 | edge_weight = edge_weight * deg_inv[col] 158 | elif normalization == 'row': 159 | row, _ = edge_index 160 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 161 | deg_inv = 1. 
/ deg 162 | deg_inv[deg_inv == float('inf')] = 0 163 | edge_weight = edge_weight * deg_inv[row] 164 | elif normalization is None: 165 | pass 166 | else: 167 | raise ValueError( 168 | 'Transition matrix normalization {} unknown.'.format( 169 | normalization)) 170 | 171 | return edge_index, edge_weight 172 | 173 | def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes, 174 | method, **kwargs): 175 | r"""Calculate the (dense) diffusion on a given sparse graph. 176 | Note that these exact variants are not scalable. They densify the 177 | adjacency matrix and calculate either its inverse or its matrix 178 | exponential. 179 | Args: 180 | edge_index (LongTensor): The edge indices. 181 | edge_weight (Tensor): One-dimensional edge weights. 182 | num_nodes (int): Number of nodes. 183 | method (str): Diffusion method: 184 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. 185 | Additionally expects the parameter: 186 | - **alpha** (*float*) - Return probability in PPR. 187 | Commonly lies in :obj:`[0.05, 0.2]`. 188 | 2. :obj:`"heat"`: Use heat kernel diffusion. 189 | Additionally expects the parameter: 190 | - **t** (*float*) - Time of diffusion. Commonly lies in 191 | :obj:`[2, 10]`. 192 | 3. :obj:`"coeff"`: Freely choose diffusion coefficients. 193 | Additionally expects the parameter: 194 | - **coeffs** (*List[float]*) - List of coefficients 195 | :obj:`theta_k` for each power of the transition matrix 196 | (starting at :obj:`0`). 
197 | :rtype: (:class:`Tensor`) 198 | """ 199 | if method == 'ppr': 200 | # α (I_n + (α - 1) A)^-1 201 | edge_weight = (kwargs['alpha'] - 1) * edge_weight 202 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 203 | fill_value=1, 204 | num_nodes=num_nodes) 205 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() 206 | diff_matrix = kwargs['alpha'] * torch.inverse(mat) 207 | 208 | elif method == 'heat': 209 | # exp(t (A - I_n)) 210 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 211 | fill_value=-1, 212 | num_nodes=num_nodes) 213 | edge_weight = kwargs['t'] * edge_weight 214 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() 215 | undirected = is_undirected(edge_index, edge_weight, num_nodes) 216 | diff_matrix = self.__expm__(mat, undirected) 217 | 218 | elif method == 'coeff': 219 | adj_matrix = to_dense_adj(edge_index, 220 | edge_attr=edge_weight).squeeze() 221 | mat = torch.eye(num_nodes, device=edge_index.device) 222 | 223 | diff_matrix = kwargs['coeffs'][0] * mat 224 | for coeff in kwargs['coeffs'][1:]: 225 | mat = mat @ adj_matrix 226 | diff_matrix += coeff * mat 227 | else: 228 | raise ValueError('Exact GDC diffusion {} unknown.'.format(method)) 229 | 230 | return diff_matrix 231 | 232 | def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes, 233 | normalization, method, **kwargs): 234 | r"""Calculate the approximate, sparse diffusion on a given sparse 235 | graph. 236 | Args: 237 | edge_index (LongTensor): The edge indices. 238 | edge_weight (Tensor): One-dimensional edge weights. 239 | num_nodes (int): Number of nodes. 240 | normalization (str): Transition matrix normalization scheme 241 | (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`). 242 | See :func:`GDC.transition_matrix` for details. 243 | method (str): Diffusion method: 244 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. 
245 | Additionally expects the parameters: 246 | - **alpha** (*float*) - Return probability in PPR. 247 | Commonly lies in :obj:`[0.05, 0.2]`. 248 | - **eps** (*float*) - Threshold for PPR calculation stopping 249 | criterion (:obj:`edge_weight >= eps * out_degree`). 250 | Recommended default: :obj:`1e-4`. 251 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 252 | """ 253 | if method == 'ppr': 254 | if normalization == 'sym': 255 | # Calculate original degrees. 256 | _, col = edge_index 257 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 258 | 259 | edge_index_np = edge_index.cpu().numpy() 260 | # Assumes coalesced edge_index. 261 | _, indptr, out_degree = np.unique(edge_index_np[0], 262 | return_index=True, 263 | return_counts=True) 264 | 265 | neighbors, neighbor_weights = GDC.__calc_ppr__( 266 | indptr, edge_index_np[1], out_degree, kwargs['alpha'], 267 | kwargs['eps']) 268 | ppr_normalization = 'col' if normalization == 'col' else 'row' 269 | edge_index, edge_weight = self.__neighbors_to_graph__( 270 | neighbors, neighbor_weights, ppr_normalization, 271 | device=edge_index.device) 272 | edge_index = edge_index.to(torch.long) 273 | 274 | if normalization == 'sym': 275 | # We can change the normalization from row-normalized to 276 | # symmetric by multiplying the resulting matrix with D^{1/2} 277 | # from the left and D^{-1/2} from the right. 278 | # Since we use the original degrees for this it will be like 279 | # we had used symmetric normalization from the beginning 280 | # (except for errors due to approximation). 
281 | row, col = edge_index 282 | deg_inv = deg.sqrt() 283 | deg_inv_sqrt = deg.pow(-0.5) 284 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 285 | edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col] 286 | elif normalization in ['col', 'row']: 287 | pass 288 | else: 289 | raise ValueError( 290 | ('Transition matrix normalization {} not implemented for ' 291 | 'non-exact GDC computation.').format(normalization)) 292 | 293 | elif method == 'heat': 294 | raise NotImplementedError( 295 | ('Currently no fast heat kernel is implemented. You are ' 296 | 'welcome to create one yourself, e.g., based on ' 297 | '"Kloster and Gleich: Heat kernel based community detection ' 298 | '(KDD 2014)."')) 299 | else: 300 | raise ValueError( 301 | 'Approximate GDC diffusion {} unknown.'.format(method)) 302 | 303 | return edge_index, edge_weight 304 | 305 | def sparsify_dense(self, matrix, method, **kwargs): 306 | r"""Sparsifies the given dense matrix. 307 | Args: 308 | matrix (Tensor): Matrix to sparsify. 309 | num_nodes (int): Number of nodes. 310 | method (str): Method of sparsification. Options: 311 | 1. :obj:`"threshold"`: Remove all edges with weights smaller 312 | than :obj:`eps`. 313 | Additionally expects one of these parameters: 314 | - **eps** (*float*) - Threshold to bound edges at. 315 | - **avg_degree** (*int*) - If :obj:`eps` is not given, 316 | it can optionally be calculated by calculating the 317 | :obj:`eps` required to achieve a given :obj:`avg_degree`. 318 | 2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per 319 | node (column). 320 | Additionally expects the following parameters: 321 | - **k** (*int*) - Specifies the number of edges to keep. 322 | - **dim** (*int*) - The axis along which to take the top 323 | :obj:`k`. 
324 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 325 | """ 326 | assert matrix.shape[0] == matrix.shape[1] 327 | N = matrix.shape[1] 328 | 329 | if method == 'threshold': 330 | if 'eps' not in kwargs.keys(): 331 | kwargs['eps'] = self.__calculate_eps__(matrix, N, 332 | kwargs['avg_degree']) 333 | 334 | edge_index = torch.nonzero(matrix >= kwargs['eps']).t() 335 | edge_index_flat = edge_index[0] * N + edge_index[1] 336 | edge_weight = matrix.flatten()[edge_index_flat] 337 | 338 | elif method == 'topk': 339 | assert kwargs['dim'] in [0, 1] 340 | sort_idx = torch.argsort(matrix, dim=kwargs['dim'], 341 | descending=True) 342 | if kwargs['dim'] == 0: 343 | top_idx = sort_idx[:kwargs['k']] 344 | edge_weight = torch.gather(matrix, dim=kwargs['dim'], 345 | index=top_idx).flatten() 346 | 347 | row_idx = torch.arange(0, N, device=matrix.device).repeat( 348 | kwargs['k']) 349 | edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0) 350 | else: 351 | top_idx = sort_idx[:, :kwargs['k']] 352 | edge_weight = torch.gather(matrix, dim=kwargs['dim'], 353 | index=top_idx).flatten() 354 | 355 | col_idx = torch.arange( 356 | 0, N, device=matrix.device).repeat_interleave(kwargs['k']) 357 | edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0) 358 | else: 359 | raise ValueError('GDC sparsification {} unknown.'.format(method)) 360 | 361 | return edge_index, edge_weight 362 | 363 | def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method, 364 | **kwargs): 365 | r"""Sparsifies a given sparse graph further. 366 | Args: 367 | edge_index (LongTensor): The edge indices. 368 | edge_weight (Tensor): One-dimensional edge weights. 369 | num_nodes (int): Number of nodes. 370 | method (str): Method of sparsification: 371 | 1. :obj:`"threshold"`: Remove all edges with weights smaller 372 | than :obj:`eps`. 373 | Additionally expects one of these parameters: 374 | - **eps** (*float*) - Threshold to bound edges at. 
375 | - **avg_degree** (*int*) - If :obj:`eps` is not given, 376 | it can optionally be calculated by calculating the 377 | :obj:`eps` required to achieve a given :obj:`avg_degree`. 378 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 379 | """ 380 | if method == 'threshold': 381 | if 'eps' not in kwargs.keys(): 382 | kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes, 383 | kwargs['avg_degree']) 384 | 385 | remaining_edge_idx = torch.nonzero( 386 | edge_weight >= kwargs['eps']).flatten() 387 | edge_index = edge_index[:, remaining_edge_idx] 388 | edge_weight = edge_weight[remaining_edge_idx] 389 | elif method == 'topk': 390 | raise NotImplementedError( 391 | 'Sparse topk sparsification not implemented.') 392 | else: 393 | raise ValueError('GDC sparsification {} unknown.'.format(method)) 394 | 395 | return edge_index, edge_weight 396 | 397 | def __expm__(self, matrix, symmetric): 398 | r"""Calculates matrix exponential. 399 | Args: 400 | matrix (Tensor): Matrix to take exponential of. 401 | symmetric (bool): Specifies whether the matrix is symmetric. 402 | :rtype: (:class:`Tensor`) 403 | """ 404 | if symmetric: 405 | e, V = torch.symeig(matrix, eigenvectors=True) 406 | diff_mat = V @ torch.diag(e.exp()) @ V.t() 407 | else: 408 | diff_mat_np = expm(matrix.cpu().numpy()) 409 | diff_mat = torch.Tensor(diff_mat_np).to(matrix.device) 410 | return diff_mat 411 | 412 | def __calculate_eps__(self, matrix, num_nodes, avg_degree): 413 | r"""Calculates threshold necessary to achieve a given average degree. 414 | Args: 415 | matrix (Tensor): Adjacency matrix or edge weights. 416 | num_nodes (int): Number of nodes. 417 | avg_degree (int): Target average degree. 
418 | :rtype: (:class:`float`) 419 | """ 420 | sorted_edges = torch.sort(matrix.flatten(), descending=True).values 421 | if avg_degree * num_nodes > len(sorted_edges): 422 | return -np.inf 423 | return sorted_edges[avg_degree * num_nodes - 1] 424 | 425 | def __neighbors_to_graph__(self, neighbors, neighbor_weights, 426 | normalization='row', device='cpu'): 427 | r"""Combine a list of neighbors and neighbor weights to create a sparse 428 | graph. 429 | Args: 430 | neighbors (List[List[int]]): List of neighbors for each node. 431 | neighbor_weights (List[List[float]]): List of weights for the 432 | neighbors of each node. 433 | normalization (str): Normalization of resulting matrix 434 | (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`) 435 | device (torch.device): Device to create output tensors on. 436 | (default: :obj:`"cpu"`) 437 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 438 | """ 439 | edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device) 440 | i = np.repeat(np.arange(len(neighbors)), 441 | np.fromiter(map(len, neighbors), dtype=np.int)) 442 | j = np.concatenate(neighbors) 443 | if normalization == 'col': 444 | edge_index = torch.Tensor(np.vstack([j, i])).to(device) 445 | N = len(neighbors) 446 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) 447 | elif normalization == 'row': 448 | edge_index = torch.Tensor(np.vstack([i, j])).to(device) 449 | else: 450 | raise ValueError( 451 | f"PPR matrix normalization {normalization} unknown.") 452 | return edge_index, edge_weight 453 | 454 | @staticmethod 455 | @jit() 456 | def __calc_ppr__(indptr, indices, out_degree, alpha, eps): 457 | r"""Calculate the personalized PageRank vector for all nodes 458 | using a variant of the Andersen algorithm 459 | (see Andersen et al. :Local Graph Partitioning using PageRank Vectors.) 460 | Args: 461 | indptr (np.ndarray): Index pointer for the sparse matrix 462 | (CSR-format). 
463 | indices (np.ndarray): Indices of the sparse matrix entries 464 | (CSR-format). 465 | out_degree (np.ndarray): Out-degree of each node. 466 | alpha (float): Alpha of the PageRank to calculate. 467 | eps (float): Threshold for PPR calculation stopping criterion 468 | (:obj:`edge_weight >= eps * out_degree`). 469 | :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`) 470 | """ 471 | alpha_eps = alpha * eps 472 | js = [] 473 | vals = [] 474 | for inode in range(len(out_degree)): 475 | p = {inode: 0.0} 476 | r = {} 477 | r[inode] = alpha 478 | q = [inode] 479 | while len(q) > 0: 480 | unode = q.pop() 481 | 482 | res = r[unode] if unode in r else 0 483 | if unode in p: 484 | p[unode] += res 485 | else: 486 | p[unode] = res 487 | r[unode] = 0 488 | for vnode in indices[indptr[unode]:indptr[unode + 1]]: 489 | _val = (1 - alpha) * res / out_degree[unode] 490 | if vnode in r: 491 | r[vnode] += _val 492 | else: 493 | r[vnode] = _val 494 | 495 | res_vnode = r[vnode] if vnode in r else 0 496 | if res_vnode >= alpha_eps * out_degree[vnode]: 497 | if vnode not in q: 498 | q.append(vnode) 499 | js.append(list(p.keys())) 500 | vals.append(list(p.values())) 501 | return js, vals 502 | 503 | def __repr__(self): 504 | return '{}()'.format(self.__class__.__name__) -------------------------------------------------------------------------------- /imports/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Mwiza Kunda 2 | # Copyright (C) 2017 Sarah Parisot , Sofia Ira Ktena 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. 16 | 17 | 18 | import os 19 | import warnings 20 | import glob 21 | import csv 22 | import re 23 | import numpy as np 24 | import scipy.io as sio 25 | import sys 26 | from nilearn import connectome 27 | import pandas as pd 28 | from scipy.spatial import distance 29 | from scipy import signal 30 | from sklearn.compose import ColumnTransformer 31 | from sklearn.preprocessing import Normalizer 32 | from sklearn.preprocessing import OrdinalEncoder 33 | from sklearn.preprocessing import OneHotEncoder 34 | from sklearn.preprocessing import StandardScaler 35 | warnings.filterwarnings("ignore") 36 | 37 | # Input data variables 38 | 39 | root_folder = '/home/azureuser/projects/BrainGNN/data/' 40 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal') 41 | phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv') 42 | 43 | 44 | def fetch_filenames(subject_IDs, file_type, atlas): 45 | """ 46 | subject_list : list of short subject IDs in string format 47 | file_type : must be one of the available file types 48 | filemapping : resulting file name format 49 | returns: 50 | filenames : list of filetypes (same length as subject_list) 51 | """ 52 | 53 | filemapping = {'func_preproc': '_func_preproc.nii.gz', 54 | 'rois_' + atlas: '_rois_' + atlas + '.1D'} 55 | # The list to be filled 56 | filenames = [] 57 | 58 | # Fill list with requested file paths 59 | for i in range(len(subject_IDs)): 60 | os.chdir(data_folder) 61 | try: 62 | try: 63 | os.chdir(data_folder) 64 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) 65 | 
except: 66 | os.chdir(data_folder + '/' + subject_IDs[i]) 67 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) 68 | except IndexError: 69 | filenames.append('N/A') 70 | return filenames 71 | 72 | 73 | # Get timeseries arrays for list of subjects 74 | def get_timeseries(subject_list, atlas_name, silence=False): 75 | """ 76 | subject_list : list of short subject IDs in string format 77 | atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200 78 | returns: 79 | time_series : list of timeseries arrays, each of shape (timepoints x regions) 80 | """ 81 | 82 | timeseries = [] 83 | for i in range(len(subject_list)): 84 | subject_folder = os.path.join(data_folder, subject_list[i]) 85 | ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')] 86 | fl = os.path.join(subject_folder, ro_file[0]) 87 | if silence != True: 88 | print("Reading timeseries file %s" % fl) 89 | timeseries.append(np.loadtxt(fl, skiprows=0)) 90 | 91 | return timeseries 92 | 93 | 94 | # compute connectivity matrices 95 | def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234, 96 | n_subjects='', save=True, save_path=data_folder): 97 | """ 98 | timeseries : timeseries table for subject (timepoints x regions) 99 | subjects : subject IDs 100 | atlas_name : name of the parcellation atlas used 101 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 102 | iter_no : tangent connectivity iteration number for cross validation evaluation 103 | save : save the connectivity matrix to a file 104 | save_path : specify path to save the matrix if different from subject folder 105 | returns: 106 | connectivity : connectivity matrix (regions x regions) 107 | """ 108 | 109 | if kind in ['TPE', 'TE', 'correlation','partial correlation']: 110 | if kind not in ['TPE', 'TE']: 111 | conn_measure = connectome.ConnectivityMeasure(kind=kind) 112 | connectivity = conn_measure.fit_transform(timeseries) 113 | else: 114 | if kind == 'TPE': 115 | conn_measure = connectome.ConnectivityMeasure(kind='correlation') 116 | conn_mat = conn_measure.fit_transform(timeseries) 117 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 118 | connectivity_fit = conn_measure.fit(conn_mat) 119 | connectivity = connectivity_fit.transform(conn_mat) 120 | else: 121 | conn_measure = connectome.ConnectivityMeasure(kind='tangent') 122 | connectivity_fit = conn_measure.fit(timeseries) 123 | connectivity = connectivity_fit.transform(timeseries) 124 | 125 | if save: 126 | if kind not in ['TPE', 'TE']: 127 | for i, subj_id in enumerate(subjects): 128 | subject_file = os.path.join(save_path, subj_id, 129 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat') 130 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 131 | return connectivity 132 | else: 133 | for i, subj_id in enumerate(subjects): 134 | subject_file = os.path.join(save_path, subj_id, 135 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str( 136 | iter_no) + '_' + str(seed) + '_' + validation_ext + str( 137 | n_subjects) + '.mat') 138 | sio.savemat(subject_file, {'connectivity': connectivity[i]}) 139 | return connectivity_fit 140 | 141 | 142 | # Get the list of subject IDs 143 | 144 | def get_ids(num_subjects=None): 145 | """ 146 | return: 147 | subject_IDs : list of all subject IDs 148 | """ 149 | 150 | 
subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt'), dtype=str) 151 | 152 | if num_subjects is not None: 153 | subject_IDs = subject_IDs[:num_subjects] 154 | 155 | return subject_IDs 156 | 157 | 158 | # Get phenotype values for a list of subjects 159 | def get_subject_score(subject_list, score): 160 | scores_dict = {} 161 | 162 | with open(phenotype) as csv_file: 163 | reader = csv.DictReader(csv_file) 164 | for row in reader: 165 | if row['SUB_ID'] in subject_list: 166 | if score == 'HANDEDNESS_CATEGORY': 167 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 168 | scores_dict[row['SUB_ID']] = 'R' 169 | elif row[score] == 'Mixed': 170 | scores_dict[row['SUB_ID']] = 'Ambi' 171 | elif row[score] == 'L->R': 172 | scores_dict[row['SUB_ID']] = 'Ambi' 173 | else: 174 | scores_dict[row['SUB_ID']] = row[score] 175 | elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'): 176 | if (row[score].strip() == '-9999') or (row[score].strip() == ''): 177 | scores_dict[row['SUB_ID']] = 100 178 | else: 179 | scores_dict[row['SUB_ID']] = float(row[score]) 180 | 181 | else: 182 | scores_dict[row['SUB_ID']] = row[score] 183 | 184 | return scores_dict 185 | 186 | 187 | # preprocess phenotypes. 
Categorical -> ordinal representation 188 | def preprocess_phenotypes(pheno_ft, params): 189 | if params['model'] == 'MIDA': 190 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough') 191 | else: 192 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough') 193 | 194 | pheno_ft = ct.fit_transform(pheno_ft) 195 | pheno_ft = pheno_ft.astype('float32') 196 | 197 | return (pheno_ft) 198 | 199 | 200 | # create phenotype feature vector to concatenate with fmri feature vectors 201 | def phenotype_ft_vector(pheno_ft, num_subjects, params): 202 | gender = pheno_ft[:, 0] 203 | if params['model'] == 'MIDA': 204 | eye = pheno_ft[:, 0] 205 | hand = pheno_ft[:, 2] 206 | age = pheno_ft[:, 3] 207 | fiq = pheno_ft[:, 4] 208 | else: 209 | eye = pheno_ft[:, 2] 210 | hand = pheno_ft[:, 3] 211 | age = pheno_ft[:, 4] 212 | fiq = pheno_ft[:, 5] 213 | 214 | phenotype_ft = np.zeros((num_subjects, 4)) 215 | phenotype_ft_eye = np.zeros((num_subjects, 2)) 216 | phenotype_ft_hand = np.zeros((num_subjects, 3)) 217 | 218 | for i in range(num_subjects): 219 | phenotype_ft[i, int(gender[i])] = 1 220 | phenotype_ft[i, -2] = age[i] 221 | phenotype_ft[i, -1] = fiq[i] 222 | phenotype_ft_eye[i, int(eye[i])] = 1 223 | phenotype_ft_hand[i, int(hand[i])] = 1 224 | 225 | if params['model'] == 'MIDA': 226 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) 227 | else: 228 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1) 229 | 230 | return phenotype_ft 231 | 232 | 233 | # Load precomputed fMRI connectivity networks 234 | def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal", 235 | variable='connectivity'): 236 | """ 237 | subject_list : list of subject IDs 238 | kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation 239 | atlas_name : name of the parcellation atlas used 240 | variable : variable name in the .mat file that has been used to save the precomputed networks 241 | return: 242 | matrix : feature matrix of connectivity networks (num_subjects x network_size) 243 | """ 244 | 245 | all_networks = [] 246 | for subject in subject_list: 247 | if len(kind.split()) == 2: 248 | kind = '_'.join(kind.split()) 249 | fl = os.path.join(data_folder, subject, 250 | subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat") 251 | 252 | 253 | matrix = sio.loadmat(fl)[variable] 254 | all_networks.append(matrix) 255 | 256 | if kind in ['TE', 'TPE']: 257 | norm_networks = [mat for mat in all_networks] 258 | else: 259 | norm_networks = [np.arctanh(mat) for mat in all_networks] 260 | 261 | networks = np.stack(norm_networks) 262 | 263 | return networks 264 | 265 | -------------------------------------------------------------------------------- /imports/read_abide_stats_parall.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Xiaoxiao Li 3 | Date: 2019/02/24 4 | ''' 5 | 6 | import os.path as osp 7 | from os import listdir 8 | import os 9 | import glob 10 | import h5py 11 | 12 | import torch 13 | import numpy as np 14 | from scipy.io import loadmat 15 | from torch_geometric.data import Data 16 | import networkx as nx 17 | from networkx.convert_matrix import from_numpy_matrix 18 | import multiprocessing 19 | from torch_sparse import coalesce 20 | from torch_geometric.utils import remove_self_loops 21 | from functools import partial 22 | import deepdish as dd 23 | from imports.gdc import GDC 24 | 25 | 26 | def split(data, batch): 27 | node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0) 28 | node_slice = torch.cat([torch.tensor([0]), node_slice]) 29 | 30 | row, _ = data.edge_index 31 | edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0) 32 | edge_slice = 
torch.cat([torch.tensor([0]), edge_slice]) 33 | 34 | # Edge indices should start at zero for every graph. 35 | data.edge_index -= node_slice[batch[row]].unsqueeze(0) 36 | 37 | slices = {'edge_index': edge_slice} 38 | if data.x is not None: 39 | slices['x'] = node_slice 40 | if data.edge_attr is not None: 41 | slices['edge_attr'] = edge_slice 42 | if data.y is not None: 43 | if data.y.size(0) == batch.size(0): 44 | slices['y'] = node_slice 45 | else: 46 | slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long) 47 | if data.pos is not None: 48 | slices['pos'] = node_slice 49 | 50 | return data, slices 51 | 52 | 53 | def cat(seq): 54 | seq = [item for item in seq if item is not None] 55 | seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq] 56 | return torch.cat(seq, dim=-1).squeeze() if len(seq) > 0 else None 57 | 58 | class NoDaemonProcess(multiprocessing.Process): 59 | @property 60 | def daemon(self): 61 | return False 62 | 63 | @daemon.setter 64 | def daemon(self, value): 65 | pass 66 | 67 | 68 | class NoDaemonContext(type(multiprocessing.get_context())): 69 | Process = NoDaemonProcess 70 | 71 | 72 | def read_data(data_dir): 73 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] 74 | onlyfiles.sort() 75 | batch = [] 76 | pseudo = [] 77 | y_list = [] 78 | edge_att_list, edge_index_list,att_list = [], [], [] 79 | 80 | # parallar computing 81 | cores = multiprocessing.cpu_count() 82 | pool = multiprocessing.Pool(processes=cores) 83 | #pool = MyPool(processes = cores) 84 | func = partial(read_sigle_data, data_dir) 85 | 86 | import timeit 87 | 88 | start = timeit.default_timer() 89 | 90 | res = pool.map(func, onlyfiles) 91 | 92 | pool.close() 93 | pool.join() 94 | 95 | stop = timeit.default_timer() 96 | 97 | print('Time: ', stop - start) 98 | 99 | 100 | 101 | for j in range(len(res)): 102 | edge_att_list.append(res[j][0]) 103 | edge_index_list.append(res[j][1]+j*res[j][4]) 104 | att_list.append(res[j][2]) 105 | 
y_list.append(res[j][3]) 106 | batch.append([j]*res[j][4]) 107 | pseudo.append(np.diag(np.ones(res[j][4]))) 108 | 109 | edge_att_arr = np.concatenate(edge_att_list) 110 | edge_index_arr = np.concatenate(edge_index_list, axis=1) 111 | att_arr = np.concatenate(att_list, axis=0) 112 | pseudo_arr = np.concatenate(pseudo, axis=0) 113 | y_arr = np.stack(y_list) 114 | edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float() 115 | att_torch = torch.from_numpy(att_arr).float() 116 | y_torch = torch.from_numpy(y_arr).long() # classification 117 | batch_torch = torch.from_numpy(np.hstack(batch)).long() 118 | edge_index_torch = torch.from_numpy(edge_index_arr).long() 119 | pseudo_torch = torch.from_numpy(pseudo_arr).float() 120 | data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos = pseudo_torch ) 121 | 122 | 123 | data, slices = split(data, batch_torch) 124 | 125 | return data, slices 126 | 127 | 128 | def read_sigle_data(data_dir,filename,use_gdc =False): 129 | 130 | temp = dd.io.load(osp.join(data_dir, filename)) 131 | 132 | # read edge and edge attribute 133 | pcorr = np.abs(temp['pcorr'][()]) 134 | 135 | num_nodes = pcorr.shape[0] 136 | G = from_numpy_matrix(pcorr) 137 | A = nx.to_scipy_sparse_matrix(G) 138 | adj = A.tocoo() 139 | edge_att = np.zeros(len(adj.row)) 140 | for i in range(len(adj.row)): 141 | edge_att[i] = pcorr[adj.row[i], adj.col[i]] 142 | 143 | edge_index = np.stack([adj.row, adj.col]) 144 | edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att)) 145 | edge_index = edge_index.long() 146 | edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, 147 | num_nodes) 148 | att = temp['corr'][()] 149 | label = temp['label'][()] 150 | 151 | att_torch = torch.from_numpy(att).float() 152 | y_torch = torch.from_numpy(np.array(label)).long() # classification 153 | 154 | data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, 
edge_attr=edge_att) 155 | 156 | if use_gdc: 157 | ''' 158 | Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html 159 | ''' 160 | data.edge_attr = data.edge_attr.squeeze() 161 | gdc = GDC(self_loop_weight=1, normalization_in='sym', 162 | normalization_out='col', 163 | diffusion_kwargs=dict(method='ppr', alpha=0.2), 164 | sparsification_kwargs=dict(method='topk', k=20, 165 | dim=0), exact=True) 166 | data = gdc(data) 167 | return data.edge_attr.data.numpy(),data.edge_index.data.numpy(),data.x.data.numpy(),data.y.data.item(),num_nodes 168 | 169 | else: 170 | return edge_att.data.numpy(),edge_index.data.numpy(),att,label,num_nodes 171 | 172 | if __name__ == "__main__": 173 | data_dir = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw' 174 | filename = '50346.h5' 175 | read_sigle_data(data_dir, filename) 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /imports/utils.py: -------------------------------------------------------------------------------- 1 | from scipy import stats 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | from scipy.io import loadmat 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn.model_selection import KFold 8 | 9 | 10 | def train_val_test_split(kfold = 5, fold = 0): 11 | n_sub = 1035 12 | id = list(range(n_sub)) 13 | 14 | 15 | import random 16 | random.seed(123) 17 | random.shuffle(id) 18 | 19 | kf = KFold(n_splits=kfold, random_state=123,shuffle = True) 20 | kf2 = KFold(n_splits=kfold-1, shuffle=True, random_state = 666) 21 | 22 | 23 | test_index = list() 24 | train_index = list() 25 | val_index = list() 26 | 27 | for tr,te in kf.split(np.array(id)): 28 | test_index.append(te) 29 | tr, val = list(kf2.split(tr))[0] 30 | train_index.append(tr) 31 | val_index.append(val) 32 | 33 | train_id = train_index[fold] 34 | test_id = test_index[fold] 35 | 
val_id = val_index[fold] 36 | 37 | return train_id,val_id,test_id -------------------------------------------------------------------------------- /net/braingnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch_geometric.nn import TopKPooling 5 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp 6 | from torch_geometric.utils import (add_self_loops, sort_edge_index, 7 | remove_self_loops) 8 | from torch_sparse import spspmm 9 | 10 | from net.braingraphconv import MyNNConv 11 | 12 | 13 | ########################################################################################################################## 14 | class Network(torch.nn.Module): 15 | def __init__(self, indim, ratio, nclass, k=8, R=200): 16 | ''' 17 | 18 | :param indim: (int) node feature dimension 19 | :param ratio: (float) pooling ratio in (0,1) 20 | :param nclass: (int) number of classes 21 | :param k: (int) number of communities 22 | :param R: (int) number of ROIs 23 | ''' 24 | super(Network, self).__init__() 25 | 26 | self.indim = indim 27 | self.dim1 = 32 28 | self.dim2 = 32 29 | self.dim3 = 512 30 | self.dim4 = 256 31 | self.dim5 = 8 32 | self.k = k 33 | self.R = R 34 | 35 | self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim1 * self.indim)) 36 | self.conv1 = MyNNConv(self.indim, self.dim1, self.n1, normalize=False) 37 | self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid) 38 | self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim2 * self.dim1)) 39 | self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False) 40 | self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid) 41 | 42 | #self.fc1 = torch.nn.Linear((self.dim2) * 2, self.dim2) 43 | self.fc1 = 
torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2) 44 | self.bn1 = torch.nn.BatchNorm1d(self.dim2) 45 | self.fc2 = torch.nn.Linear(self.dim2, self.dim3) 46 | self.bn2 = torch.nn.BatchNorm1d(self.dim3) 47 | self.fc3 = torch.nn.Linear(self.dim3, nclass) 48 | 49 | 50 | 51 | 52 | def forward(self, x, edge_index, batch, edge_attr, pos): 53 | 54 | x = self.conv1(x, edge_index, edge_attr, pos) 55 | x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch) 56 | 57 | pos = pos[perm] 58 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 59 | 60 | edge_attr = edge_attr.squeeze() 61 | edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0)) 62 | 63 | x = self.conv2(x, edge_index, edge_attr, pos) 64 | x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index,edge_attr, batch) 65 | 66 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 67 | 68 | x = torch.cat([x1,x2], dim=1) 69 | x = self.bn1(F.relu(self.fc1(x))) 70 | x = F.dropout(x, p=0.5, training=self.training) 71 | x = self.bn2(F.relu(self.fc2(x))) 72 | x= F.dropout(x, p=0.5, training=self.training) 73 | x = F.log_softmax(self.fc3(x), dim=-1) 74 | 75 | return x,self.pool1.weight,self.pool2.weight, torch.sigmoid(score1).view(x.size(0),-1), torch.sigmoid(score2).view(x.size(0),-1) 76 | 77 | def augment_adj(self, edge_index, edge_weight, num_nodes): 78 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 79 | num_nodes=num_nodes) 80 | edge_index, edge_weight = sort_edge_index(edge_index, edge_weight, 81 | num_nodes) 82 | edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index, 83 | edge_weight, num_nodes, num_nodes, 84 | num_nodes) 85 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 86 | return edge_index, edge_weight 87 | 88 | -------------------------------------------------------------------------------- /net/braingraphconv.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Parameter 4 | from net.brainmsgpassing import MyMessagePassing 5 | from torch_geometric.utils import add_remaining_self_loops,softmax 6 | 7 | from torch_geometric.typing import (OptTensor) 8 | 9 | from net.inits import uniform 10 | 11 | 12 | class MyNNConv(MyMessagePassing): 13 | def __init__(self, in_channels, out_channels, nn, normalize=False, bias=True, 14 | **kwargs): 15 | super(MyNNConv, self).__init__(aggr='mean', **kwargs) 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.normalize = normalize 20 | self.nn = nn 21 | #self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 22 | 23 | if bias: 24 | self.bias = Parameter(torch.Tensor(out_channels)) 25 | else: 26 | self.register_parameter('bias', None) 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | # uniform(self.in_channels, self.weight) 32 | uniform(self.in_channels, self.bias) 33 | 34 | def forward(self, x, edge_index, edge_weight=None, pseudo= None, size=None): 35 | """""" 36 | edge_weight = edge_weight.squeeze() 37 | if size is None and torch.is_tensor(x): 38 | edge_index, edge_weight = add_remaining_self_loops( 39 | edge_index, edge_weight, 1, x.size(0)) 40 | 41 | weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels) 42 | if torch.is_tensor(x): 43 | x = torch.matmul(x.unsqueeze(1), weight).squeeze(1) 44 | else: 45 | x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1), 46 | None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1)) 47 | 48 | # weight = self.nn(pseudo).view(-1, self.out_channels,self.in_channels) 49 | # if torch.is_tensor(x): 50 | # x = torch.matmul(x.unsqueeze(1), weight.permute(0,2,1)).squeeze(1) 51 | # else: 52 | # x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), 
weight).squeeze(1), 53 | # None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1)) 54 | 55 | return self.propagate(edge_index, size=size, x=x, 56 | edge_weight=edge_weight) 57 | 58 | def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTensor): 59 | edge_weight = softmax(edge_weight, edge_index_i, ptr, size_i) 60 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 61 | 62 | def update(self, aggr_out): 63 | if self.bias is not None: 64 | aggr_out = aggr_out + self.bias 65 | if self.normalize: 66 | aggr_out = F.normalize(aggr_out, p=2, dim=-1) 67 | return aggr_out 68 | 69 | def __repr__(self): 70 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 71 | self.out_channels) 72 | 73 | -------------------------------------------------------------------------------- /net/brainmsgpassing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | 4 | import torch 5 | # from torch_geometric.utils import scatter_ 6 | from torch_scatter import scatter,scatter_add 7 | 8 | special_args = [ 9 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j' 10 | ] 11 | __size_error_msg__ = ('All tensors which should get mapped to the same source ' 12 | 'or target nodes must be of same size in dimension 0.') 13 | 14 | is_python2 = sys.version_info[0] < 3 15 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec 16 | 17 | 18 | class MyMessagePassing(torch.nn.Module): 19 | r"""Base class for creating message passing layers 20 | .. 
math:: 21 | \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, 22 | \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} 23 | \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right), 24 | where :math:`\square` denotes a differentiable, permutation invariant 25 | function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}` 26 | and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as 27 | MLPs. 28 | See `here `__ for the accompanying tutorial. 30 | Args: 31 | aggr (string, optional): The aggregation scheme to use 32 | (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`). 33 | (default: :obj:`"add"`) 34 | flow (string, optional): The flow direction of message passing 35 | (:obj:`"source_to_target"` or :obj:`"target_to_source"`). 36 | (default: :obj:`"source_to_target"`) 37 | node_dim (int, optional): The axis along which to propagate. 38 | (default: :obj:`0`) 39 | """ 40 | def __init__(self, aggr='add', flow='source_to_target', node_dim=0): 41 | super(MyMessagePassing, self).__init__() 42 | 43 | self.aggr = aggr 44 | assert self.aggr in ['add', 'mean', 'max'] 45 | 46 | self.flow = flow 47 | assert self.flow in ['source_to_target', 'target_to_source'] 48 | 49 | self.node_dim = node_dim 50 | assert self.node_dim >= 0 51 | 52 | self.__message_args__ = getargspec(self.message)[0][1:] 53 | self.__special_args__ = [(i, arg) 54 | for i, arg in enumerate(self.__message_args__) 55 | if arg in special_args] 56 | self.__message_args__ = [ 57 | arg for arg in self.__message_args__ if arg not in special_args 58 | ] 59 | self.__update_args__ = getargspec(self.update)[0][2:] 60 | 61 | def propagate(self, edge_index, size=None, **kwargs): 62 | r"""The initial call to start propagating messages. 63 | Args: 64 | edge_index (Tensor): The indices of a general (sparse) assignment 65 | matrix with shape :obj:`[N, M]` (can be directed or 66 | undirected). 
67 | size (list or tuple, optional): The size :obj:`[N, M]` of the 68 | assignment matrix. If set to :obj:`None`, the size is tried to 69 | get automatically inferred and assumed to be symmetric. 70 | (default: :obj:`None`) 71 | **kwargs: Any additional data which is needed to construct messages 72 | and to update node embeddings. 73 | """ 74 | 75 | dim = self.node_dim 76 | size = [None, None] if size is None else list(size) 77 | assert len(size) == 2 78 | 79 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) 80 | ij = {"_i": i, "_j": j} 81 | 82 | message_args = [] 83 | for arg in self.__message_args__: 84 | if arg[-2:] in ij.keys(): 85 | tmp = kwargs.get(arg[:-2], None) 86 | if tmp is None: # pragma: no cover 87 | message_args.append(tmp) 88 | else: 89 | idx = ij[arg[-2:]] 90 | if isinstance(tmp, tuple) or isinstance(tmp, list): 91 | assert len(tmp) == 2 92 | if tmp[1 - idx] is not None: 93 | if size[1 - idx] is None: 94 | size[1 - idx] = tmp[1 - idx].size(dim) 95 | if size[1 - idx] != tmp[1 - idx].size(dim): 96 | raise ValueError(__size_error_msg__) 97 | tmp = tmp[idx] 98 | 99 | if tmp is None: 100 | message_args.append(tmp) 101 | else: 102 | if size[idx] is None: 103 | size[idx] = tmp.size(dim) 104 | if size[idx] != tmp.size(dim): 105 | raise ValueError(__size_error_msg__) 106 | 107 | tmp = torch.index_select(tmp, dim, edge_index[idx]) 108 | message_args.append(tmp) 109 | else: 110 | message_args.append(kwargs.get(arg, None)) 111 | 112 | size[0] = size[1] if size[0] is None else size[0] 113 | size[1] = size[0] if size[1] is None else size[1] 114 | 115 | kwargs['edge_index'] = edge_index 116 | kwargs['size'] = size 117 | 118 | for (idx, arg) in self.__special_args__: 119 | if arg[-2:] in ij.keys(): 120 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]]) 121 | else: 122 | message_args.insert(idx, kwargs[arg]) 123 | 124 | update_args = [kwargs[arg] for arg in self.__update_args__] 125 | 126 | out = self.message(*message_args) 127 | # out = 
scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i]) 128 | out = scatter_add(out, edge_index[i], dim, dim_size=size[i]) 129 | out = self.update(out, *update_args) 130 | 131 | return out 132 | 133 | def message(self, x_j): # pragma: no cover 134 | r"""Constructs messages to node :math:`i` in analogy to 135 | :math:`\phi_{\mathbf{\Theta}}` for each edge in 136 | :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and 137 | :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. 138 | Can take any argument which was initially passed to :meth:`propagate`. 139 | In addition, tensors passed to :meth:`propagate` can be mapped to the 140 | respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or 141 | :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. 142 | """ 143 | 144 | return x_j 145 | 146 | def update(self, aggr_out): # pragma: no cover 147 | r"""Updates node embeddings in analogy to 148 | :math:`\gamma_{\mathbf{\Theta}}` for each node 149 | :math:`i \in \mathcal{V}`. 
150 | Takes in the output of aggregation as first argument and any argument 151 | which was initially passed to :meth:`propagate`.""" 152 | 153 | return aggr_out 154 | -------------------------------------------------------------------------------- /net/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | if tensor is not None: 12 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | if tensor is not None: 18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work 2 | anaconda-client==1.7.2 3 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work 4 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist 5 | appdirs==1.4.4 6 | argh==0.26.2 7 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work 8 | arrow==0.13.1 9 | ase==3.21.1 10 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work 11 | astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work 12 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work 13 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work 14 | atomicwrites==1.4.0 15 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work 16 | autopep8 @ 
file:///tmp/build/80754af9/autopep8_1615918855173/work 17 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work 18 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 19 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work 20 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work 21 | binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work 22 | bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work 23 | bkcharts==0.2 24 | black==19.10b0 25 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work 26 | bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work 27 | boto==2.49.0 28 | Bottleneck==1.3.2 29 | brotlipy==0.7.0 30 | certifi==2020.12.5 31 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work 32 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work 33 | click @ file:///home/linux1/recipes/ci/click_1610990599742/work 34 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work 35 | clyent==1.2.2 36 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 37 | contextlib2==0.6.0.post1 38 | cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work 39 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work 40 | cycler==0.10.0 41 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work 42 | cytoolz==0.11.0 43 | dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work 44 | decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work 45 | deepdish==0.3.6 46 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work 47 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work 48 | distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work 49 | docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work 50 | entrypoints==0.3 51 | et-xmlfile==1.0.1 52 | 
fastcache==1.1.0 53 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work 54 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work 55 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work 56 | fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work 57 | future==0.18.2 58 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work 59 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work 60 | gmpy2==2.0.8 61 | googledrivedownloader==0.4 62 | greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work 63 | h5py==2.10.0 64 | HeapDict==1.0.1 65 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work 66 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work 67 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work 68 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work 69 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work 70 | inflection==0.5.1 71 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work 72 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work 73 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl 74 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work 75 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 76 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work 77 | isodate==0.6.0 78 | isort @ file:///tmp/build/80754af9/isort_1616355431277/work 79 | itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work 80 | jdcal==1.4.1 81 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work 82 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work 83 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work 84 | jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work 85 | joblib @ 
file:///tmp/build/80754af9/joblib_1613502643832/work 86 | json5==0.9.5 87 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work 88 | jupyter==1.0.0 89 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work 90 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work 91 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work 92 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work 93 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work 94 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work 95 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work 96 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work 97 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work 98 | keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work 99 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work 100 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work 101 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work 102 | llvmlite==0.36.0 103 | locket==0.2.1 104 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work 105 | MarkupSafe==1.1.1 106 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work 107 | mccabe==0.6.1 108 | mistune==0.8.4 109 | mkl-fft==1.3.0 110 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work 111 | mkl-service==2.3.0 112 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work 113 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work 114 | mpmath==1.2.1 115 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work 116 | multipledispatch==0.6.0 117 | mypy-extensions==0.4.3 118 | nbclassic @ 
file:///tmp/build/80754af9/nbclassic_1616085367084/work 119 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work 120 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work 121 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work 122 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work 123 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work 124 | nibabel==3.2.1 125 | nilearn==0.7.1 126 | nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work 127 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work 128 | notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work 129 | numba @ file:///tmp/build/80754af9/numba_1616774046117/work 130 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work 131 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work 132 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work 133 | olefile==0.46 134 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work 135 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work 136 | pandas==1.2.4 137 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work 138 | parso==0.7.0 139 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work 140 | path @ file:///tmp/build/80754af9/path_1614022220526/work 141 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work 142 | pathspec==0.7.0 143 | pathtools==0.1.2 144 | patsy==0.5.1 145 | pep8==1.7.1 146 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 147 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 148 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work 149 | pkginfo==1.7.0 150 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work 151 | ply==3.11 152 | poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work 153 | prometheus-client @ 
file:///tmp/build/80754af9/prometheus_client_1618088486455/work 154 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 155 | protobuf==3.17.0 156 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work 157 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 158 | py @ file:///tmp/build/80754af9/py_1607971587848/work 159 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work 160 | pycosat==0.6.3 161 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 162 | pycurl==7.43.0.6 163 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work 164 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work 165 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work 166 | Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work 167 | pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work 168 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work 169 | pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work 170 | pyodbc===4.0.0-unsupported 171 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 172 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work 173 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work 174 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 175 | pytest==6.2.3 176 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work 177 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work 178 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work 179 | python-louvain==0.15 180 | python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work 181 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 182 | PyWavelets @ 
file:///tmp/build/80754af9/pywavelets_1601658317819/work 183 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work 184 | PyYAML==5.4.1 185 | pyzmq==20.0.0 186 | QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work 187 | qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl 188 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work 189 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work 190 | QtPy==1.9.0 191 | rdflib==5.0.0 192 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work 193 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work 194 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work 195 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work 196 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work 197 | scikit-image==0.16.2 198 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work 199 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work 200 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work 201 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work 202 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work 203 | simplegeneric==0.8.1 204 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work 205 | sip==4.19.13 206 | six @ file:///tmp/build/80754af9/six_1605205327372/work 207 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work 208 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work 209 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work 210 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work 211 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work 212 | Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work 213 | sphinxcontrib-applehelp @ 
file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work 214 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work 215 | sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work 216 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work 217 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work 218 | sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work 219 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work 220 | spyder @ file:///tmp/build/80754af9/spyder_1618327905127/work 221 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1617396566288/work 222 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work 223 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work 224 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work 225 | tables==3.6.1 226 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work 227 | tensorboardX==2.2 228 | terminado==0.9.4 229 | testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work 230 | text-unidecode==1.3 231 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work 232 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl 233 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work 234 | tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work 235 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work 236 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work 237 | torch==1.7.0 238 | torch-cluster==1.5.9 239 | torch-geometric==1.7.0 240 | torch-scatter==2.0.6 241 | torch-sparse==0.6.9 242 | torch-spline-conv==1.2.1 243 | torchaudio==0.7.0a0+ac17b64 244 | torchvision==0.8.0 245 | tornado @ 
file:///tmp/build/80754af9/tornado_1606942300299/work 246 | tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work 247 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 248 | tsBNgen==1.0.0 249 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work 250 | typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work 251 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work 252 | unicodecsv==0.14.1 253 | Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work 254 | urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work 255 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work 256 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 257 | webencodings==0.5.1 258 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work 259 | whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work 260 | widgetsnbextension==3.5.1 261 | wrapt==1.12.1 262 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work 263 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work 264 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work 265 | xlwt==1.3.0 266 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work 267 | zict==2.0.0 268 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work 269 | zope.event==4.5.0 270 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work 271 | --------------------------------------------------------------------------------