├── .idea
│   ├── .gitignore
│   ├── GNN_biomarker_MEDIA.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── webServers.xml
├── 01-fetch_data.py
├── 02-process_data.py
├── 03-main.py
├── README.md
├── imports
│   ├── ABIDEDataset.py
│   ├── __inits__.py
│   ├── __pycache__
│   │   ├── preprocess_data.cpython-36.pyc
│   │   └── preprocess_data.cpython-38.pyc
│   ├── gdc.py
│   ├── preprocess_data.py
│   ├── read_abide_stats_parall.py
│   └── utils.py
├── net
│   ├── braingnn.py
│   ├── braingraphconv.py
│   ├── brainmsgpassing.py
│   └── inits.py
└── requirements.txt
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/GNN_biomarker_MEDIA.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/01-fetch_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Mwiza Kunda
2 | # Copyright (C) 2017 Sarah Parisot, Sofia Ira Ktena
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | from nilearn import datasets
18 | import argparse
19 | from imports import preprocess_data as Reader
20 | import os
21 | import shutil
22 | import sys
23 |
24 | # Input data variables
25 | code_folder = os.getcwd()
26 | root_folder = '/home/azureuser/projects/BrainGNN/data/'
27 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')
28 | if not os.path.exists(data_folder):
29 | os.makedirs(data_folder)
30 | shutil.copyfile(os.path.join(root_folder,'subject_ID.txt'), os.path.join(data_folder, 'subject_IDs.txt'))
31 |
32 | def str2bool(v):
33 | if isinstance(v, bool):
34 | return v
35 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
36 | return True
37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
38 | return False
39 | else:
40 | raise argparse.ArgumentTypeError('Boolean value expected.')
41 |
42 |
43 | def main():
44 | parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices')
45 | parser.add_argument('--pipeline', default='cpac', type=str,
46 | help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak.'
47 | ' default: cpac.')
48 | parser.add_argument('--atlas', default='cc200',
49 | help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.')
50 | parser.add_argument('--download', default=True, type=str2bool,
51 |                         help='Download data or just compute functional connectivity. default: True')
52 | args = parser.parse_args()
53 | print(args)
54 |
55 | params = dict()
56 |
57 | pipeline = args.pipeline
58 | atlas = args.atlas
59 | download = args.download
60 |
61 | # Files to fetch
62 |
63 | files = ['rois_' + atlas]
64 |
65 | filemapping = {'func_preproc': 'func_preproc.nii.gz',
66 | files[0]: files[0] + '.1D'}
67 |
68 |
69 | # Download database files
70 | if download == True:
71 | abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline,
72 | band_pass_filtering=True, global_signal_regression=False, derivatives=files,
73 | quality_checked=False)
74 |
75 | subject_IDs = Reader.get_ids() #changed path to data path
76 | subject_IDs = subject_IDs.tolist()
77 |
78 | # Create a folder for each subject
79 | for s, fname in zip(subject_IDs, Reader.fetch_filenames(subject_IDs, files[0], atlas)):
80 | subject_folder = os.path.join(data_folder, s)
81 | if not os.path.exists(subject_folder):
82 | os.mkdir(subject_folder)
83 |
84 | # Get the base filename for each subject
85 | base = fname.split(files[0])[0]
86 |
87 | # Move each subject file to the subject folder
88 | for fl in files:
89 | if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])):
90 | shutil.move(base + filemapping[fl], subject_folder)
91 |
92 | time_series = Reader.get_timeseries(subject_IDs, atlas)
93 |
94 | # Compute and save connectivity matrices
95 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation')
96 | Reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation')
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
--------------------------------------------------------------------------------
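01-fetch_data.py always pulls the full ABIDE I release. As a hedged illustration only (not part of the repository), the sketch below uses the same nilearn call with an n_subjects limit to smoke-test the download path first; the folder path and subset size are assumptions to adapt.

```python
# Minimal sketch, assuming nilearn is installed and the root folder used in
# 01-fetch_data.py above; fetches only a handful of subjects to verify the
# download path before committing to the full ABIDE I release.
from nilearn import datasets

root_folder = '/home/azureuser/projects/BrainGNN/data/'
abide = datasets.fetch_abide_pcp(
    data_dir=root_folder,
    pipeline='cpac',                 # same pipeline as the script's default
    band_pass_filtering=True,
    global_signal_regression=False,
    derivatives=['rois_cc200'],      # ROI time series for the cc200 atlas
    quality_checked=False,
    n_subjects=5,                    # small subset for a smoke test (assumption)
)
print(abide.keys())
```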
/02-process_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Mwiza Kunda
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 |
17 | import sys
18 | import argparse
19 | import pandas as pd
20 | import numpy as np
21 | from imports import preprocess_data as Reader
22 | import deepdish as dd
23 | import warnings
24 | import os
25 |
26 | warnings.filterwarnings("ignore")
27 | root_folder = '/home/azureuser/projects/BrainGNN/data/'
28 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')
29 |
30 | # Process boolean command line arguments
31 | def str2bool(v):
32 | if isinstance(v, bool):
33 | return v
34 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
35 | return True
36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
37 | return False
38 | else:
39 | raise argparse.ArgumentTypeError('Boolean value expected.')
40 |
41 |
42 | def main():
43 |     parser = argparse.ArgumentParser(description='Construct graph inputs (correlation and partial correlation '
44 |                                                  'matrices) for each ABIDE subject and save them as .h5 files')
45 | parser.add_argument('--atlas', default='cc200',
46 | help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.')
47 |     parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 123.')
48 | parser.add_argument('--nclass', default=2, type=int, help='Number of classes. default:2')
49 |
50 |
51 | args = parser.parse_args()
52 | print('Arguments: \n', args)
53 |
54 |
55 | params = dict()
56 |
57 | params['seed'] = args.seed # seed for random initialisation
58 |
59 | # Algorithm choice
60 | params['atlas'] = args.atlas # Atlas for network construction
61 | atlas = args.atlas # Atlas for network construction (node definition)
62 |
63 | # Get subject IDs and class labels
64 | subject_IDs = Reader.get_ids()
65 | labels = Reader.get_subject_score(subject_IDs, score='DX_GROUP')
66 |
67 | # Number of subjects and classes for binary classification
68 | num_classes = args.nclass
69 | num_subjects = len(subject_IDs)
70 | params['n_subjects'] = num_subjects
71 |
72 | # Initialise variables for class labels and acquisition sites
73 | # 1 is autism, 2 is control
74 | y_data = np.zeros([num_subjects, num_classes]) # n x 2
75 | y = np.zeros([num_subjects, 1]) # n x 1
76 |
77 | # Get class labels for all subjects
78 | for i in range(num_subjects):
79 | y_data[i, int(labels[subject_IDs[i]]) - 1] = 1
80 | y[i] = int(labels[subject_IDs[i]])
81 |
82 | # Compute feature vectors (vectorised connectivity networks)
83 | fea_corr = Reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas) #(1035, 200, 200)
84 | fea_pcorr = Reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas) #(1035, 200, 200)
85 |
86 | if not os.path.exists(os.path.join(data_folder,'raw')):
87 | os.makedirs(os.path.join(data_folder,'raw'))
88 | for i, subject in enumerate(subject_IDs):
89 |         dd.io.save(os.path.join(data_folder,'raw',subject+'.h5'),{'corr':fea_corr[i],'pcorr':fea_pcorr[i],'label':y[i]%2})  # DX_GROUP {1,2} -> label {1,0}
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
--------------------------------------------------------------------------------
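For orientation, a minimal sketch of reading back one of the per-subject HDF5 files that 02-process_data.py writes; the key names come from the dd.io.save call above, while the data path and subject ID are placeholders.

```python
# Minimal sketch, assuming deepdish is installed and 02-process_data.py has
# already written the per-subject .h5 files; the subject ID is a placeholder.
import os
import deepdish as dd

data_folder = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/'
sample = dd.io.load(os.path.join(data_folder, 'raw', '50003.h5'))  # placeholder ID

print(sample['corr'].shape)    # (200, 200) correlation matrix (cc200 atlas)
print(sample['pcorr'].shape)   # (200, 200) partial correlation matrix
print(sample['label'])         # 1 = autism, 0 = control (DX_GROUP mapped via % 2)
```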
/03-main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import time
5 | import copy
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.optim import lr_scheduler
10 | from tensorboardX import SummaryWriter
11 |
12 | from imports.ABIDEDataset import ABIDEDataset
13 | from torch_geometric.data import DataLoader
14 | from net.braingnn import Network
15 | from imports.utils import train_val_test_split
16 | from sklearn.metrics import classification_report, confusion_matrix
17 |
18 | torch.manual_seed(123)
19 |
20 | EPS = 1e-10
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
26 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
27 | parser.add_argument('--batchSize', type=int, default=100, help='size of the batches')
28 | parser.add_argument('--dataroot', type=str, default='/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal', help='root directory of the dataset')
29 | parser.add_argument('--fold', type=int, default=0, help='training which fold')
30 | parser.add_argument('--lr', type = float, default=0.01, help='learning rate')
31 | parser.add_argument('--stepsize', type=int, default=20, help='scheduler step size')
32 | parser.add_argument('--gamma', type=float, default=0.5, help='scheduler shrinking rate')
33 | parser.add_argument('--weightdecay', type=float, default=5e-3, help='regularization')
34 | parser.add_argument('--lamb0', type=float, default=1, help='classification loss weight')
35 | parser.add_argument('--lamb1', type=float, default=0, help='s1 unit regularization')
36 | parser.add_argument('--lamb2', type=float, default=0, help='s2 unit regularization')
37 | parser.add_argument('--lamb3', type=float, default=0.1, help='s1 entropy regularization')
38 | parser.add_argument('--lamb4', type=float, default=0.1, help='s2 entropy regularization')
39 | parser.add_argument('--lamb5', type=float, default=0, help='s1 consistence regularization')
40 | parser.add_argument('--layer', type=int, default=2, help='number of GNN layers')
41 | parser.add_argument('--ratio', type=float, default=0.5, help='pooling ratio')
42 | parser.add_argument('--indim', type=int, default=200, help='feature dim')
43 | parser.add_argument('--nroi', type=int, default=200, help='num of ROIs')
44 | parser.add_argument('--nclass', type=int, default=2, help='num of classes')
45 | parser.add_argument('--load_model', type=bool, default=False)  # note: argparse's type=bool treats any non-empty string as True
46 | parser.add_argument('--save_model', type=bool, default=True)
47 | parser.add_argument('--optim', type=str, default='SGD', help='optimization method: SGD, Adam')
48 | parser.add_argument('--save_path', type=str, default='./model/', help='path to save model')
49 | opt = parser.parse_args()
50 |
51 | if not os.path.exists(opt.save_path):
52 | os.makedirs(opt.save_path)
53 |
54 | #################### Parameter Initialization #######################
55 | path = opt.dataroot
56 | name = 'ABIDE'
57 | save_model = opt.save_model
58 | load_model = opt.load_model
59 | opt_method = opt.optim
60 | num_epoch = opt.n_epochs
61 | fold = opt.fold
62 | writer = SummaryWriter(os.path.join('./log',str(fold)))
63 |
64 |
65 |
66 | ################## Define Dataloader ##################################
67 |
68 | dataset = ABIDEDataset(path,name)
69 | dataset.data.y = dataset.data.y.squeeze()
70 | dataset.data.x[dataset.data.x == float('inf')] = 0
71 |
72 | tr_index,val_index,te_index = train_val_test_split(fold=fold)
73 | train_mask = torch.zeros(len(dataset), dtype=torch.uint8)
74 | val_mask = torch.zeros(len(dataset), dtype=torch.uint8)
75 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8)
76 | train_mask[tr_index] = 1
77 | val_mask[val_index] = 1
78 | test_mask[te_index] = 1
79 | train_dataset = dataset[train_mask]
80 | val_dataset = dataset[val_mask]
81 | test_dataset = dataset[test_mask]
82 |
83 |
84 | train_loader = DataLoader(train_dataset,batch_size=opt.batchSize, shuffle= True)
85 | val_loader = DataLoader(val_dataset, batch_size=opt.batchSize, shuffle=False)
86 | test_loader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False)
87 |
88 |
89 |
90 | ############### Define Graph Deep Learning Network ##########################
91 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device)
92 | print(model)
93 |
94 | if opt_method == 'Adam':
95 | optimizer = torch.optim.Adam(model.parameters(), lr= opt.lr, weight_decay=opt.weightdecay)
96 | elif opt_method == 'SGD':
97 | optimizer = torch.optim.SGD(model.parameters(), lr =opt.lr, momentum = 0.9, weight_decay=opt.weightdecay, nesterov = True)
98 |
99 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.stepsize, gamma=opt.gamma)
100 |
101 | ############################### Define Other Loss Functions ########################################
102 | def topk_loss(s,ratio):
103 | if ratio > 0.5:
104 | ratio = 1-ratio
105 | s = s.sort(dim=1).values
106 | res = -torch.log(s[:,-int(s.size(1)*ratio):]+EPS).mean() -torch.log(1-s[:,:int(s.size(1)*ratio)]+EPS).mean()
107 | return res
108 |
109 |
110 | def consist_loss(s):
111 | if len(s) == 0:
112 | return 0
113 | s = torch.sigmoid(s)
114 | W = torch.ones(s.shape[0],s.shape[0])
115 | D = torch.eye(s.shape[0])*torch.sum(W,dim=1)
116 | L = D-W
117 | L = L.to(device)
118 | res = torch.trace(torch.transpose(s,0,1) @ L @ s)/(s.shape[0]*s.shape[0])
119 | return res
120 |
121 | ###################### Network Training Function#####################################
122 | def train(epoch):
123 | print('train...........')
124 | scheduler.step()
125 |
126 | for param_group in optimizer.param_groups:
127 | print("LR", param_group['lr'])
128 | model.train()
129 | s1_list = []
130 | s2_list = []
131 | loss_all = 0
132 | step = 0
133 | for data in train_loader:
134 | data = data.to(device)
135 | optimizer.zero_grad()
136 | output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
137 | s1_list.append(s1.view(-1).detach().cpu().numpy())
138 | s2_list.append(s2.view(-1).detach().cpu().numpy())
139 |
140 | loss_c = F.nll_loss(output, data.y)
141 |
142 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2
143 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2
144 | loss_tpk1 = topk_loss(s1,opt.ratio)
145 | loss_tpk2 = topk_loss(s2,opt.ratio)
146 | loss_consist = 0
147 | for c in range(opt.nclass):
148 | loss_consist += consist_loss(s1[data.y == c])
149 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
150 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist
151 | writer.add_scalar('train/classification_loss', loss_c, epoch*len(train_loader)+step)
152 | writer.add_scalar('train/unit_loss1', loss_p1, epoch*len(train_loader)+step)
153 | writer.add_scalar('train/unit_loss2', loss_p2, epoch*len(train_loader)+step)
154 | writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch*len(train_loader)+step)
155 | writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch*len(train_loader)+step)
156 | writer.add_scalar('train/GCL_loss', loss_consist, epoch*len(train_loader)+step)
157 | step = step + 1
158 |
159 | loss.backward()
160 | loss_all += loss.item() * data.num_graphs
161 | optimizer.step()
162 |
163 | s1_arr = np.hstack(s1_list)
164 | s2_arr = np.hstack(s2_list)
165 | return loss_all / len(train_dataset), s1_arr, s2_arr ,w1,w2
166 |
167 |
168 | ###################### Network Testing Function#####################################
169 | def test_acc(loader):
170 | model.eval()
171 | correct = 0
172 | for data in loader:
173 | data = data.to(device)
174 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos)
175 | pred = outputs[0].max(dim=1)[1]
176 | correct += pred.eq(data.y).sum().item()
177 |
178 | return correct / len(loader.dataset)
179 |
180 | def test_loss(loader,epoch):
181 | print('testing...........')
182 | model.eval()
183 | loss_all = 0
184 | for data in loader:
185 | data = data.to(device)
186 | output, w1, w2, s1, s2= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos)
187 | loss_c = F.nll_loss(output, data.y)
188 |
189 | loss_p1 = (torch.norm(w1, p=2)-1) ** 2
190 | loss_p2 = (torch.norm(w2, p=2)-1) ** 2
191 | loss_tpk1 = topk_loss(s1,opt.ratio)
192 | loss_tpk2 = topk_loss(s2,opt.ratio)
193 | loss_consist = 0
194 | for c in range(opt.nclass):
195 | loss_consist += consist_loss(s1[data.y == c])
196 | loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
197 | + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist
198 |
199 | loss_all += loss.item() * data.num_graphs
200 | return loss_all / len(loader.dataset)
201 |
202 | #######################################################################################
203 | ############################ Model Training #########################################
204 | #######################################################################################
205 | best_model_wts = copy.deepcopy(model.state_dict())
206 | best_loss = 1e10
207 | for epoch in range(0, num_epoch):
208 | since = time.time()
209 | tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch)
210 | tr_acc = test_acc(train_loader)
211 | val_acc = test_acc(val_loader)
212 | val_loss = test_loss(val_loader,epoch)
213 | time_elapsed = time.time() - since
214 | print('*====**')
215 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
216 |     print('Epoch: {:03d}, Train Loss: {:.7f}, '
217 |           'Train Acc: {:.7f}, Val Loss: {:.7f}, Val Acc: {:.7f}'.format(epoch, tr_loss,
218 |                                                                         tr_acc, val_loss, val_acc))
219 |
220 | writer.add_scalars('Acc',{'train_acc':tr_acc,'val_acc':val_acc}, epoch)
221 | writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss}, epoch)
222 | writer.add_histogram('Hist/hist_s1', s1_arr, epoch)
223 | writer.add_histogram('Hist/hist_s2', s2_arr, epoch)
224 |
225 | if val_loss < best_loss and epoch > 5:
226 | print("saving best model")
227 | best_loss = val_loss
228 | best_model_wts = copy.deepcopy(model.state_dict())
229 | if save_model:
230 | torch.save(best_model_wts, os.path.join(opt.save_path,str(fold)+'.pth'))
231 |
232 | #######################################################################################
233 | ######################### Testing on testing set ######################################
234 | #######################################################################################
235 |
236 | if opt.load_model:
237 | model = Network(opt.indim,opt.ratio,opt.nclass).to(device)
238 | model.load_state_dict(torch.load(os.path.join(opt.save_path,str(fold)+'.pth')))
239 | model.eval()
240 | preds = []
241 | correct = 0
242 | for data in val_loader:
243 | data = data.to(device)
244 | outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos)
245 | pred = outputs[0].max(1)[1]
246 | preds.append(pred.cpu().detach().numpy())
247 | correct += pred.eq(data.y).sum().item()
248 | preds = np.concatenate(preds,axis=0)
249 | trues = val_dataset.data.y.cpu().detach().numpy()
250 |     cm = confusion_matrix(trues, preds)
251 |     print("Confusion matrix\n", cm)
252 |     print(classification_report(trues, preds))
253 |
254 | else:
255 | model.load_state_dict(best_model_wts)
256 | model.eval()
257 | test_accuracy = test_acc(test_loader)
258 | test_l= test_loss(test_loader,0)
259 | print("===========================")
260 | print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l))
261 | print(opt)
262 |
263 |
--------------------------------------------------------------------------------
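To make concrete what the top-k regulariser in 03-main.py rewards, a self-contained sketch of topk_loss on toy score vectors (the numbers are made up for illustration): near-binary node scores give a smaller loss than ambiguous ones.

```python
# Minimal sketch: topk_loss as defined in 03-main.py, applied to toy node scores.
import torch

EPS = 1e-10

def topk_loss(s, ratio):
    if ratio > 0.5:
        ratio = 1 - ratio
    s = s.sort(dim=1).values
    return -torch.log(s[:, -int(s.size(1) * ratio):] + EPS).mean() \
           - torch.log(1 - s[:, :int(s.size(1) * ratio)] + EPS).mean()

confident = torch.tensor([[0.95, 0.90, 0.10, 0.05]])  # near-binary scores -> small loss
uncertain = torch.tensor([[0.55, 0.50, 0.50, 0.45]])  # ambiguous scores -> larger loss
print(topk_loss(confident, 0.5).item())  # ~0.16
print(topk_loss(uncertain, 0.5).item())  # ~1.29
```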
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Neural Network for Brain Network Analysis
2 | A preliminary implementation of BrainGNN
3 |
4 |
5 | ## Usage
6 | ### Setup
7 | **pip**
8 |
9 | See the `requirements.txt` for environment configuration.
10 | ```bash
11 | pip install -r requirements.txt
12 | ```
13 | **PYG**
14 |
15 | To install the PyG (PyTorch Geometric) library, please refer to the [installation documentation](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)
16 |
17 | ### Dataset
18 | **ABIDE**
19 |
20 | We treat each fMRI scan as a brain graph. To download the data and construct the graphs, run:
21 | ```bash
22 | python 01-fetch_data.py
23 | python 02-process_data.py
24 | ```
25 |
26 | ### How to run classification?
27 | Training and testing are integrated in `03-main.py`. To run:
28 | ```bash
29 | python 03-main.py
30 | ```
31 |
32 |
33 | ## Citation
34 | If you find the code and dataset useful, please cite our paper.
35 | ```latex
36 | @article{li2020braingnn,
37 |   title={{BrainGNN}: Interpretable brain graph neural network for {fMRI} analysis},
38 |   author={Li, Xiaoxiao and Zhou, Yuan and Dvornek, Nicha and Zhang, Muhan and Gao, Siyuan and Zhuang, Juntang and Scheinost, Dustin and Staib, Lawrence and Ventola, Pamela and Duncan, James},
39 | journal={bioRxiv},
40 | year={2020},
41 | publisher={Cold Spring Harbor Laboratory}
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
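Before running the scripts listed in the README's Setup section, a minimal sanity check (a sketch, assuming torch and torch_geometric were installed per requirements.txt and the PyG instructions):

```python
# Minimal sketch: verify that the core dependencies import and report whether a
# GPU is visible; 03-main.py falls back to CPU automatically if it is not.
import torch
import torch_geometric

print('torch', torch.__version__)
print('torch_geometric', torch_geometric.__version__)
print('CUDA available:', torch.cuda.is_available())
```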
/imports/ABIDEDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import InMemoryDataset,Data
3 | from os.path import join, isfile
4 | from os import listdir
5 | import numpy as np
6 | import os.path as osp
7 | from imports.read_abide_stats_parall import read_data
8 |
9 |
10 | class ABIDEDataset(InMemoryDataset):
11 | def __init__(self, root, name, transform=None, pre_transform=None):
12 | self.root = root
13 | self.name = name
14 | super(ABIDEDataset, self).__init__(root,transform, pre_transform)
15 | self.data, self.slices = torch.load(self.processed_paths[0])
16 |
17 | @property
18 | def raw_file_names(self):
19 | data_dir = osp.join(self.root,'raw')
20 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))]
21 | onlyfiles.sort()
22 | return onlyfiles
23 | @property
24 | def processed_file_names(self):
25 | return 'data.pt'
26 |
27 | def download(self):
28 | # Download to `self.raw_dir`.
29 | return
30 |
31 | def process(self):
32 | # Read data into huge `Data` list.
33 | self.data, self.slices = read_data(self.raw_dir)
34 |
35 | if self.pre_filter is not None:
36 | data_list = [self.get(idx) for idx in range(len(self))]
37 | data_list = [data for data in data_list if self.pre_filter(data)]
38 | self.data, self.slices = self.collate(data_list)
39 |
40 | if self.pre_transform is not None:
41 | data_list = [self.get(idx) for idx in range(len(self))]
42 | data_list = [self.pre_transform(data) for data in data_list]
43 | self.data, self.slices = self.collate(data_list)
44 |
45 | torch.save((self.data, self.slices), self.processed_paths[0])
46 |
47 | def __repr__(self):
48 | return '{}({})'.format(self.name, len(self))
49 |
--------------------------------------------------------------------------------
/imports/__inits__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__inits__.py
--------------------------------------------------------------------------------
/imports/__pycache__/preprocess_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__pycache__/preprocess_data.cpython-36.pyc
--------------------------------------------------------------------------------
/imports/__pycache__/preprocess_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LifangHe/BrainGNN_Pytorch/30b78ae28a1e8d6d23004884b6c8e7010bcaf587/imports/__pycache__/preprocess_data.cpython-38.pyc
--------------------------------------------------------------------------------
/imports/gdc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numba
3 | import numpy as np
4 | from scipy.linalg import expm
5 | from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
6 | from torch_sparse import coalesce
7 | from torch_scatter import scatter_add
8 |
9 |
10 | def jit():
11 | def decorator(func):
12 | try:
13 | return numba.jit(cache=True)(func)
14 | except RuntimeError:
15 | return numba.jit(cache=False)(func)
16 |
17 | return decorator
18 |
19 |
20 | class GDC(object):
21 | r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
22 | `"Diffusion Improves Graph Learning" `_
23 | paper.
24 | .. note::
25 | The paper offers additional advice on how to choose the
26 | hyperparameters.
27 | For an example of using GCN with GDC, see `examples/gcn.py
28 | `_.
30 | Args:
31 | self_loop_weight (float, optional): Weight of the added self-loop.
32 | Set to :obj:`None` to add no self-loops. (default: :obj:`1`)
33 | normalization_in (str, optional): Normalization of the transition
34 | matrix on the original (input) graph. Possible values:
35 | :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`.
36 | See :func:`GDC.transition_matrix` for details.
37 | (default: :obj:`"sym"`)
38 | normalization_out (str, optional): Normalization of the transition
39 | matrix on the transformed GDC (output) graph. Possible values:
40 | :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`.
41 | See :func:`GDC.transition_matrix` for details.
42 | (default: :obj:`"col"`)
43 | diffusion_kwargs (dict, optional): Dictionary containing the parameters
44 | for diffusion.
45 | `method` specifies the diffusion method (:obj:`"ppr"`,
46 | :obj:`"heat"` or :obj:`"coeff"`).
47 | Each diffusion method requires different additional parameters.
48 | See :func:`GDC.diffusion_matrix_exact` or
49 | :func:`GDC.diffusion_matrix_approx` for details.
50 | (default: :obj:`dict(method='ppr', alpha=0.15)`)
51 | sparsification_kwargs (dict, optional): Dictionary containing the
52 | parameters for sparsification.
53 | `method` specifies the sparsification method (:obj:`"threshold"` or
54 | :obj:`"topk"`).
55 | Each sparsification method requires different additional
56 | parameters.
57 | See :func:`GDC.sparsify_dense` for details.
58 | (default: :obj:`dict(method='threshold', avg_degree=64)`)
59 | exact (bool, optional): Whether to exactly calculate the diffusion
60 | matrix.
61 | Note that the exact variants are not scalable.
62 | They densify the adjacency matrix and calculate either its inverse
63 | or its matrix exponential.
64 | However, the approximate variants do not support edge weights and
65 | currently only personalized PageRank and sparsification by
66 | threshold are implemented as fast, approximate versions.
67 | (default: :obj:`True`)
68 | :rtype: :class:`torch_geometric.data.Data`
69 | """
70 | def __init__(self, self_loop_weight=1, normalization_in='sym',
71 | normalization_out='col',
72 | diffusion_kwargs=dict(method='ppr', alpha=0.15),
73 | sparsification_kwargs=dict(method='threshold',
74 | avg_degree=64), exact=True):
75 | self.self_loop_weight = self_loop_weight
76 | self.normalization_in = normalization_in
77 | self.normalization_out = normalization_out
78 | self.diffusion_kwargs = diffusion_kwargs
79 | self.sparsification_kwargs = sparsification_kwargs
80 | self.exact = exact
81 |
82 | if self_loop_weight:
83 | assert exact or self_loop_weight == 1
84 |
85 | @torch.no_grad()
86 | def __call__(self, data):
87 | N = data.num_nodes
88 | edge_index = data.edge_index
89 | if data.edge_attr is None:
90 | edge_weight = torch.ones(edge_index.size(1),
91 | device=edge_index.device)
92 | else:
93 | edge_weight = data.edge_attr
94 | assert self.exact
95 | assert edge_weight.dim() == 1
96 |
97 | if self.self_loop_weight:
98 | edge_index, edge_weight = add_self_loops(
99 | edge_index, edge_weight, fill_value=self.self_loop_weight,
100 | num_nodes=N)
101 |
102 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
103 |
104 | if self.exact:
105 | edge_index, edge_weight = self.transition_matrix(
106 | edge_index, edge_weight, N, self.normalization_in)
107 | diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N,
108 | **self.diffusion_kwargs)
109 | edge_index, edge_weight = self.sparsify_dense(
110 | diff_mat, **self.sparsification_kwargs)
111 | else:
112 | edge_index, edge_weight = self.diffusion_matrix_approx(
113 | edge_index, edge_weight, N, self.normalization_in,
114 | **self.diffusion_kwargs)
115 | edge_index, edge_weight = self.sparsify_sparse(
116 | edge_index, edge_weight, N, **self.sparsification_kwargs)
117 |
118 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
119 | edge_index, edge_weight = self.transition_matrix(
120 | edge_index, edge_weight, N, self.normalization_out)
121 |
122 | data.edge_index = edge_index
123 | data.edge_attr = edge_weight
124 |
125 | return data
126 |
127 | def transition_matrix(self, edge_index, edge_weight, num_nodes,
128 | normalization):
129 | r"""Calculate the approximate, sparse diffusion on a given sparse
130 | matrix.
131 | Args:
132 | edge_index (LongTensor): The edge indices.
133 | edge_weight (Tensor): One-dimensional edge weights.
134 | num_nodes (int): Number of nodes.
135 | normalization (str): Normalization scheme:
136 | 1. :obj:`"sym"`: Symmetric normalization
137 | :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A}
138 | \mathbf{D}^{-1/2}`.
139 | 2. :obj:`"col"`: Column-wise normalization
140 | :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`.
141 | 3. :obj:`"row"`: Row-wise normalization
142 | :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`.
143 | 4. :obj:`None`: No normalization.
144 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
145 | """
146 | if normalization == 'sym':
147 | row, col = edge_index
148 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
149 | deg_inv_sqrt = deg.pow(-0.5)
150 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
151 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
152 | elif normalization == 'col':
153 | _, col = edge_index
154 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
155 | deg_inv = 1. / deg
156 | deg_inv[deg_inv == float('inf')] = 0
157 | edge_weight = edge_weight * deg_inv[col]
158 | elif normalization == 'row':
159 | row, _ = edge_index
160 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
161 | deg_inv = 1. / deg
162 | deg_inv[deg_inv == float('inf')] = 0
163 | edge_weight = edge_weight * deg_inv[row]
164 | elif normalization is None:
165 | pass
166 | else:
167 | raise ValueError(
168 | 'Transition matrix normalization {} unknown.'.format(
169 | normalization))
170 |
171 | return edge_index, edge_weight
172 |
173 | def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes,
174 | method, **kwargs):
175 | r"""Calculate the (dense) diffusion on a given sparse graph.
176 | Note that these exact variants are not scalable. They densify the
177 | adjacency matrix and calculate either its inverse or its matrix
178 | exponential.
179 | Args:
180 | edge_index (LongTensor): The edge indices.
181 | edge_weight (Tensor): One-dimensional edge weights.
182 | num_nodes (int): Number of nodes.
183 | method (str): Diffusion method:
184 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
185 | Additionally expects the parameter:
186 | - **alpha** (*float*) - Return probability in PPR.
187 | Commonly lies in :obj:`[0.05, 0.2]`.
188 | 2. :obj:`"heat"`: Use heat kernel diffusion.
189 | Additionally expects the parameter:
190 | - **t** (*float*) - Time of diffusion. Commonly lies in
191 | :obj:`[2, 10]`.
192 | 3. :obj:`"coeff"`: Freely choose diffusion coefficients.
193 | Additionally expects the parameter:
194 | - **coeffs** (*List[float]*) - List of coefficients
195 | :obj:`theta_k` for each power of the transition matrix
196 | (starting at :obj:`0`).
197 | :rtype: (:class:`Tensor`)
198 | """
199 | if method == 'ppr':
200 | # α (I_n + (α - 1) A)^-1
201 | edge_weight = (kwargs['alpha'] - 1) * edge_weight
202 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
203 | fill_value=1,
204 | num_nodes=num_nodes)
205 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
206 | diff_matrix = kwargs['alpha'] * torch.inverse(mat)
207 |
208 | elif method == 'heat':
209 | # exp(t (A - I_n))
210 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
211 | fill_value=-1,
212 | num_nodes=num_nodes)
213 | edge_weight = kwargs['t'] * edge_weight
214 | mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
215 | undirected = is_undirected(edge_index, edge_weight, num_nodes)
216 | diff_matrix = self.__expm__(mat, undirected)
217 |
218 | elif method == 'coeff':
219 | adj_matrix = to_dense_adj(edge_index,
220 | edge_attr=edge_weight).squeeze()
221 | mat = torch.eye(num_nodes, device=edge_index.device)
222 |
223 | diff_matrix = kwargs['coeffs'][0] * mat
224 | for coeff in kwargs['coeffs'][1:]:
225 | mat = mat @ adj_matrix
226 | diff_matrix += coeff * mat
227 | else:
228 | raise ValueError('Exact GDC diffusion {} unknown.'.format(method))
229 |
230 | return diff_matrix
231 |
232 | def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes,
233 | normalization, method, **kwargs):
234 | r"""Calculate the approximate, sparse diffusion on a given sparse
235 | graph.
236 | Args:
237 | edge_index (LongTensor): The edge indices.
238 | edge_weight (Tensor): One-dimensional edge weights.
239 | num_nodes (int): Number of nodes.
240 | normalization (str): Transition matrix normalization scheme
241 | (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`).
242 | See :func:`GDC.transition_matrix` for details.
243 | method (str): Diffusion method:
244 | 1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
245 | Additionally expects the parameters:
246 | - **alpha** (*float*) - Return probability in PPR.
247 | Commonly lies in :obj:`[0.05, 0.2]`.
248 | - **eps** (*float*) - Threshold for PPR calculation stopping
249 | criterion (:obj:`edge_weight >= eps * out_degree`).
250 | Recommended default: :obj:`1e-4`.
251 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
252 | """
253 | if method == 'ppr':
254 | if normalization == 'sym':
255 | # Calculate original degrees.
256 | _, col = edge_index
257 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
258 |
259 | edge_index_np = edge_index.cpu().numpy()
260 | # Assumes coalesced edge_index.
261 | _, indptr, out_degree = np.unique(edge_index_np[0],
262 | return_index=True,
263 | return_counts=True)
264 |
265 | neighbors, neighbor_weights = GDC.__calc_ppr__(
266 | indptr, edge_index_np[1], out_degree, kwargs['alpha'],
267 | kwargs['eps'])
268 | ppr_normalization = 'col' if normalization == 'col' else 'row'
269 | edge_index, edge_weight = self.__neighbors_to_graph__(
270 | neighbors, neighbor_weights, ppr_normalization,
271 | device=edge_index.device)
272 | edge_index = edge_index.to(torch.long)
273 |
274 | if normalization == 'sym':
275 | # We can change the normalization from row-normalized to
276 | # symmetric by multiplying the resulting matrix with D^{1/2}
277 | # from the left and D^{-1/2} from the right.
278 | # Since we use the original degrees for this it will be like
279 | # we had used symmetric normalization from the beginning
280 | # (except for errors due to approximation).
281 | row, col = edge_index
282 |                 deg_sqrt = deg.sqrt()  # square root of the original degrees
283 |                 deg_inv_sqrt = deg.pow(-0.5)
284 |                 deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
285 |                 edge_weight = deg_sqrt[row] * edge_weight * deg_inv_sqrt[col]
286 | elif normalization in ['col', 'row']:
287 | pass
288 | else:
289 | raise ValueError(
290 | ('Transition matrix normalization {} not implemented for '
291 | 'non-exact GDC computation.').format(normalization))
292 |
293 | elif method == 'heat':
294 | raise NotImplementedError(
295 | ('Currently no fast heat kernel is implemented. You are '
296 | 'welcome to create one yourself, e.g., based on '
297 | '"Kloster and Gleich: Heat kernel based community detection '
298 | '(KDD 2014)."'))
299 | else:
300 | raise ValueError(
301 | 'Approximate GDC diffusion {} unknown.'.format(method))
302 |
303 | return edge_index, edge_weight
304 |
305 | def sparsify_dense(self, matrix, method, **kwargs):
306 | r"""Sparsifies the given dense matrix.
307 | Args:
308 | matrix (Tensor): Matrix to sparsify.
309 | num_nodes (int): Number of nodes.
310 | method (str): Method of sparsification. Options:
311 | 1. :obj:`"threshold"`: Remove all edges with weights smaller
312 | than :obj:`eps`.
313 | Additionally expects one of these parameters:
314 | - **eps** (*float*) - Threshold to bound edges at.
315 | - **avg_degree** (*int*) - If :obj:`eps` is not given,
316 | it can optionally be calculated by calculating the
317 | :obj:`eps` required to achieve a given :obj:`avg_degree`.
318 | 2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per
319 | node (column).
320 | Additionally expects the following parameters:
321 | - **k** (*int*) - Specifies the number of edges to keep.
322 | - **dim** (*int*) - The axis along which to take the top
323 | :obj:`k`.
324 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
325 | """
326 | assert matrix.shape[0] == matrix.shape[1]
327 | N = matrix.shape[1]
328 |
329 | if method == 'threshold':
330 | if 'eps' not in kwargs.keys():
331 | kwargs['eps'] = self.__calculate_eps__(matrix, N,
332 | kwargs['avg_degree'])
333 |
334 | edge_index = torch.nonzero(matrix >= kwargs['eps']).t()
335 | edge_index_flat = edge_index[0] * N + edge_index[1]
336 | edge_weight = matrix.flatten()[edge_index_flat]
337 |
338 | elif method == 'topk':
339 | assert kwargs['dim'] in [0, 1]
340 | sort_idx = torch.argsort(matrix, dim=kwargs['dim'],
341 | descending=True)
342 | if kwargs['dim'] == 0:
343 | top_idx = sort_idx[:kwargs['k']]
344 | edge_weight = torch.gather(matrix, dim=kwargs['dim'],
345 | index=top_idx).flatten()
346 |
347 | row_idx = torch.arange(0, N, device=matrix.device).repeat(
348 | kwargs['k'])
349 | edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0)
350 | else:
351 | top_idx = sort_idx[:, :kwargs['k']]
352 | edge_weight = torch.gather(matrix, dim=kwargs['dim'],
353 | index=top_idx).flatten()
354 |
355 | col_idx = torch.arange(
356 | 0, N, device=matrix.device).repeat_interleave(kwargs['k'])
357 | edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0)
358 | else:
359 | raise ValueError('GDC sparsification {} unknown.'.format(method))
360 |
361 | return edge_index, edge_weight
362 |
363 | def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method,
364 | **kwargs):
365 | r"""Sparsifies a given sparse graph further.
366 | Args:
367 | edge_index (LongTensor): The edge indices.
368 | edge_weight (Tensor): One-dimensional edge weights.
369 | num_nodes (int): Number of nodes.
370 | method (str): Method of sparsification:
371 | 1. :obj:`"threshold"`: Remove all edges with weights smaller
372 | than :obj:`eps`.
373 | Additionally expects one of these parameters:
374 | - **eps** (*float*) - Threshold to bound edges at.
375 | - **avg_degree** (*int*) - If :obj:`eps` is not given,
376 | it can optionally be calculated by calculating the
377 | :obj:`eps` required to achieve a given :obj:`avg_degree`.
378 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
379 | """
380 | if method == 'threshold':
381 | if 'eps' not in kwargs.keys():
382 | kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes,
383 | kwargs['avg_degree'])
384 |
385 | remaining_edge_idx = torch.nonzero(
386 | edge_weight >= kwargs['eps']).flatten()
387 | edge_index = edge_index[:, remaining_edge_idx]
388 | edge_weight = edge_weight[remaining_edge_idx]
389 | elif method == 'topk':
390 | raise NotImplementedError(
391 | 'Sparse topk sparsification not implemented.')
392 | else:
393 | raise ValueError('GDC sparsification {} unknown.'.format(method))
394 |
395 | return edge_index, edge_weight
396 |
397 | def __expm__(self, matrix, symmetric):
398 | r"""Calculates matrix exponential.
399 | Args:
400 | matrix (Tensor): Matrix to take exponential of.
401 | symmetric (bool): Specifies whether the matrix is symmetric.
402 | :rtype: (:class:`Tensor`)
403 | """
404 | if symmetric:
405 | e, V = torch.symeig(matrix, eigenvectors=True)
406 | diff_mat = V @ torch.diag(e.exp()) @ V.t()
407 | else:
408 | diff_mat_np = expm(matrix.cpu().numpy())
409 | diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
410 | return diff_mat
411 |
412 | def __calculate_eps__(self, matrix, num_nodes, avg_degree):
413 | r"""Calculates threshold necessary to achieve a given average degree.
414 | Args:
415 | matrix (Tensor): Adjacency matrix or edge weights.
416 | num_nodes (int): Number of nodes.
417 | avg_degree (int): Target average degree.
418 | :rtype: (:class:`float`)
419 | """
420 | sorted_edges = torch.sort(matrix.flatten(), descending=True).values
421 | if avg_degree * num_nodes > len(sorted_edges):
422 | return -np.inf
423 | return sorted_edges[avg_degree * num_nodes - 1]
424 |
425 | def __neighbors_to_graph__(self, neighbors, neighbor_weights,
426 | normalization='row', device='cpu'):
427 | r"""Combine a list of neighbors and neighbor weights to create a sparse
428 | graph.
429 | Args:
430 | neighbors (List[List[int]]): List of neighbors for each node.
431 | neighbor_weights (List[List[float]]): List of weights for the
432 | neighbors of each node.
433 | normalization (str): Normalization of resulting matrix
434 | (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`)
435 | device (torch.device): Device to create output tensors on.
436 | (default: :obj:`"cpu"`)
437 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
438 | """
439 | edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device)
440 | i = np.repeat(np.arange(len(neighbors)),
441 |                       np.fromiter(map(len, neighbors), dtype=np.int64))  # np.int was removed in recent NumPy
442 | j = np.concatenate(neighbors)
443 | if normalization == 'col':
444 | edge_index = torch.Tensor(np.vstack([j, i])).to(device)
445 | N = len(neighbors)
446 | edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
447 | elif normalization == 'row':
448 | edge_index = torch.Tensor(np.vstack([i, j])).to(device)
449 | else:
450 | raise ValueError(
451 | f"PPR matrix normalization {normalization} unknown.")
452 | return edge_index, edge_weight
453 |
454 | @staticmethod
455 | @jit()
456 | def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
457 | r"""Calculate the personalized PageRank vector for all nodes
458 | using a variant of the Andersen algorithm
459 | (see Andersen et al. :Local Graph Partitioning using PageRank Vectors.)
460 | Args:
461 | indptr (np.ndarray): Index pointer for the sparse matrix
462 | (CSR-format).
463 | indices (np.ndarray): Indices of the sparse matrix entries
464 | (CSR-format).
465 | out_degree (np.ndarray): Out-degree of each node.
466 | alpha (float): Alpha of the PageRank to calculate.
467 | eps (float): Threshold for PPR calculation stopping criterion
468 | (:obj:`edge_weight >= eps * out_degree`).
469 | :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`)
470 | """
471 | alpha_eps = alpha * eps
472 | js = []
473 | vals = []
474 | for inode in range(len(out_degree)):
475 | p = {inode: 0.0}
476 | r = {}
477 | r[inode] = alpha
478 | q = [inode]
479 | while len(q) > 0:
480 | unode = q.pop()
481 |
482 | res = r[unode] if unode in r else 0
483 | if unode in p:
484 | p[unode] += res
485 | else:
486 | p[unode] = res
487 | r[unode] = 0
488 | for vnode in indices[indptr[unode]:indptr[unode + 1]]:
489 | _val = (1 - alpha) * res / out_degree[unode]
490 | if vnode in r:
491 | r[vnode] += _val
492 | else:
493 | r[vnode] = _val
494 |
495 | res_vnode = r[vnode] if vnode in r else 0
496 | if res_vnode >= alpha_eps * out_degree[vnode]:
497 | if vnode not in q:
498 | q.append(vnode)
499 | js.append(list(p.keys()))
500 | vals.append(list(p.values()))
501 | return js, vals
502 |
503 | def __repr__(self):
504 | return '{}()'.format(self.__class__.__name__)
--------------------------------------------------------------------------------
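GDC is imported by read_abide_stats_parall.py; as a standalone usage illustration only, here is a minimal sketch applying the transform above to a toy triangle graph. The graph and hyperparameters are assumptions chosen to show the call pattern, not values used by the pipeline.

```python
# Minimal sketch, assuming torch_geometric/torch_sparse/numba are installed and
# the script is run from the repository root so `imports.gdc` resolves.
import torch
from torch_geometric.data import Data
from imports.gdc import GDC

# 3-node undirected triangle (both directions of every edge listed explicitly)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]], dtype=torch.long)
data = Data(edge_index=edge_index, num_nodes=3)

gdc = GDC(self_loop_weight=1, normalization_in='sym', normalization_out='col',
          diffusion_kwargs=dict(method='ppr', alpha=0.15),
          sparsification_kwargs=dict(method='threshold', avg_degree=2), exact=True)
data = gdc(data)
print(data.edge_index.shape, data.edge_attr.shape)  # sparsified, diffused graph
```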
/imports/preprocess_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Mwiza Kunda
2 | # Copyright (C) 2017 Sarah Parisot, Sofia Ira Ktena
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 |
18 | import os
19 | import warnings
20 | import glob
21 | import csv
22 | import re
23 | import numpy as np
24 | import scipy.io as sio
25 | import sys
26 | from nilearn import connectome
27 | import pandas as pd
28 | from scipy.spatial import distance
29 | from scipy import signal
30 | from sklearn.compose import ColumnTransformer
31 | from sklearn.preprocessing import Normalizer
32 | from sklearn.preprocessing import OrdinalEncoder
33 | from sklearn.preprocessing import OneHotEncoder
34 | from sklearn.preprocessing import StandardScaler
35 | warnings.filterwarnings("ignore")
36 |
37 | # Input data variables
38 |
39 | root_folder = '/home/azureuser/projects/BrainGNN/data/'
40 | data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal')
41 | phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv')
42 |
43 |
44 | def fetch_filenames(subject_IDs, file_type, atlas):
45 | """
46 | subject_list : list of short subject IDs in string format
47 | file_type : must be one of the available file types
48 | filemapping : resulting file name format
49 | returns:
50 | filenames : list of filetypes (same length as subject_list)
51 | """
52 |
53 | filemapping = {'func_preproc': '_func_preproc.nii.gz',
54 | 'rois_' + atlas: '_rois_' + atlas + '.1D'}
55 | # The list to be filled
56 | filenames = []
57 |
58 | # Fill list with requested file paths
59 | for i in range(len(subject_IDs)):
60 | os.chdir(data_folder)
61 | try:
62 | try:
63 | os.chdir(data_folder)
64 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0])
65 | except:
66 | os.chdir(data_folder + '/' + subject_IDs[i])
67 | filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0])
68 | except IndexError:
69 | filenames.append('N/A')
70 | return filenames
71 |
72 |
73 | # Get timeseries arrays for list of subjects
74 | def get_timeseries(subject_list, atlas_name, silence=False):
75 | """
76 | subject_list : list of short subject IDs in string format
77 | atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200
78 | returns:
79 | time_series : list of timeseries arrays, each of shape (timepoints x regions)
80 | """
81 |
82 | timeseries = []
83 | for i in range(len(subject_list)):
84 | subject_folder = os.path.join(data_folder, subject_list[i])
85 | ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')]
86 | fl = os.path.join(subject_folder, ro_file[0])
87 | if silence != True:
88 | print("Reading timeseries file %s" % fl)
89 | timeseries.append(np.loadtxt(fl, skiprows=0))
90 |
91 | return timeseries
92 |
93 |
94 | # compute connectivity matrices
95 | def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234,
96 | n_subjects='', save=True, save_path=data_folder):
97 | """
98 | timeseries : timeseries table for subject (timepoints x regions)
99 | subjects : subject IDs
100 | atlas_name : name of the parcellation atlas used
101 | kind : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
102 | iter_no : tangent connectivity iteration number for cross validation evaluation
103 | save : save the connectivity matrix to a file
104 | save_path : specify path to save the matrix if different from subject folder
105 | returns:
106 | connectivity : connectivity matrix (regions x regions)
107 | """
108 |
109 | if kind in ['TPE', 'TE', 'correlation','partial correlation']:
110 | if kind not in ['TPE', 'TE']:
111 | conn_measure = connectome.ConnectivityMeasure(kind=kind)
112 | connectivity = conn_measure.fit_transform(timeseries)
113 | else:
114 | if kind == 'TPE':
115 | conn_measure = connectome.ConnectivityMeasure(kind='correlation')
116 | conn_mat = conn_measure.fit_transform(timeseries)
117 | conn_measure = connectome.ConnectivityMeasure(kind='tangent')
118 | connectivity_fit = conn_measure.fit(conn_mat)
119 | connectivity = connectivity_fit.transform(conn_mat)
120 | else:
121 | conn_measure = connectome.ConnectivityMeasure(kind='tangent')
122 | connectivity_fit = conn_measure.fit(timeseries)
123 | connectivity = connectivity_fit.transform(timeseries)
124 |
125 | if save:
126 | if kind not in ['TPE', 'TE']:
127 | for i, subj_id in enumerate(subjects):
128 | subject_file = os.path.join(save_path, subj_id,
129 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat')
130 | sio.savemat(subject_file, {'connectivity': connectivity[i]})
131 | return connectivity
132 | else:
133 | for i, subj_id in enumerate(subjects):
134 | subject_file = os.path.join(save_path, subj_id,
135 | subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str(
136 | iter_no) + '_' + str(seed) + '_' + validation_ext + str(
137 | n_subjects) + '.mat')
138 | sio.savemat(subject_file, {'connectivity': connectivity[i]})
139 | return connectivity_fit
140 |
141 |
142 | # Get the list of subject IDs
143 |
144 | def get_ids(num_subjects=None):
145 | """
146 | return:
147 | subject_IDs : list of all subject IDs
148 | """
149 |
150 | subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt'), dtype=str)
151 |
152 | if num_subjects is not None:
153 | subject_IDs = subject_IDs[:num_subjects]
154 |
155 | return subject_IDs
156 |
157 |
158 | # Get phenotype values for a list of subjects
159 | def get_subject_score(subject_list, score):
160 | scores_dict = {}
161 |
162 | with open(phenotype) as csv_file:
163 | reader = csv.DictReader(csv_file)
164 | for row in reader:
165 | if row['SUB_ID'] in subject_list:
166 | if score == 'HANDEDNESS_CATEGORY':
167 | if (row[score].strip() == '-9999') or (row[score].strip() == ''):
168 | scores_dict[row['SUB_ID']] = 'R'
169 | elif row[score] == 'Mixed':
170 | scores_dict[row['SUB_ID']] = 'Ambi'
171 | elif row[score] == 'L->R':
172 | scores_dict[row['SUB_ID']] = 'Ambi'
173 | else:
174 | scores_dict[row['SUB_ID']] = row[score]
175 | elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'):
176 | if (row[score].strip() == '-9999') or (row[score].strip() == ''):
177 | scores_dict[row['SUB_ID']] = 100
178 | else:
179 | scores_dict[row['SUB_ID']] = float(row[score])
180 |
181 | else:
182 | scores_dict[row['SUB_ID']] = row[score]
183 |
184 | return scores_dict
185 |
186 |
187 | # preprocess phenotypes. Categorical -> ordinal representation
188 | def preprocess_phenotypes(pheno_ft, params):
189 | if params['model'] == 'MIDA':
190 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough')
191 | else:
192 | ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough')
193 |
194 | pheno_ft = ct.fit_transform(pheno_ft)
195 | pheno_ft = pheno_ft.astype('float32')
196 |
197 | return (pheno_ft)
198 |
199 |
200 | # create phenotype feature vector to concatenate with fmri feature vectors
201 | def phenotype_ft_vector(pheno_ft, num_subjects, params):
202 | gender = pheno_ft[:, 0]
203 | if params['model'] == 'MIDA':
204 | eye = pheno_ft[:, 0]
205 | hand = pheno_ft[:, 2]
206 | age = pheno_ft[:, 3]
207 | fiq = pheno_ft[:, 4]
208 | else:
209 | eye = pheno_ft[:, 2]
210 | hand = pheno_ft[:, 3]
211 | age = pheno_ft[:, 4]
212 | fiq = pheno_ft[:, 5]
213 |
214 | phenotype_ft = np.zeros((num_subjects, 4))
215 | phenotype_ft_eye = np.zeros((num_subjects, 2))
216 | phenotype_ft_hand = np.zeros((num_subjects, 3))
217 |
218 | for i in range(num_subjects):
219 | phenotype_ft[i, int(gender[i])] = 1
220 | phenotype_ft[i, -2] = age[i]
221 | phenotype_ft[i, -1] = fiq[i]
222 | phenotype_ft_eye[i, int(eye[i])] = 1
223 | phenotype_ft_hand[i, int(hand[i])] = 1
224 |
225 | if params['model'] == 'MIDA':
226 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1)
227 | else:
228 | phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1)
229 |
230 | return phenotype_ft
231 |
232 |
233 | # Load precomputed fMRI connectivity networks
234 | def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal",
235 | variable='connectivity'):
236 | """
237 | subject_list : list of subject IDs
238 | kind : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
239 | atlas_name : name of the parcellation atlas used
240 | variable : variable name in the .mat file that has been used to save the precomputed networks
241 | return:
242 | matrix : feature matrix of connectivity networks (num_subjects x network_size)
243 | """
244 |
245 | all_networks = []
246 | for subject in subject_list:
247 | if len(kind.split()) == 2:
248 | kind = '_'.join(kind.split())
249 | fl = os.path.join(data_folder, subject,
250 | subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat")
251 |
252 |
253 | matrix = sio.loadmat(fl)[variable]
254 | all_networks.append(matrix)
255 |
256 | if kind in ['TE', 'TPE']:
257 | norm_networks = [mat for mat in all_networks]
258 | else:
259 | norm_networks = [np.arctanh(mat) for mat in all_networks]
260 |
261 | networks = np.stack(norm_networks)
262 |
263 | return networks
264 |
265 |
--------------------------------------------------------------------------------
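The heart of subject_connectivity() above is nilearn's ConnectivityMeasure; a minimal sketch on random time series makes the expected shapes explicit (real inputs come from get_timeseries()).

```python
# Minimal sketch, assuming nilearn and numpy; random data stands in for the
# (timepoints x regions) arrays returned by get_timeseries().
import numpy as np
from nilearn import connectome

rng = np.random.default_rng(0)
timeseries = [rng.standard_normal((196, 200)) for _ in range(3)]  # 3 subjects

conn_measure = connectome.ConnectivityMeasure(kind='correlation')
connectivity = conn_measure.fit_transform(timeseries)
print(connectivity.shape)  # (3, 200, 200): one ROI-by-ROI matrix per subject
```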
/imports/read_abide_stats_parall.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Xiaoxiao Li
3 | Date: 2019/02/24
4 | '''
5 |
6 | import os.path as osp
7 | from os import listdir
8 | import os
9 | import glob
10 | import h5py
11 |
12 | import torch
13 | import numpy as np
14 | from scipy.io import loadmat
15 | from torch_geometric.data import Data
16 | import networkx as nx
17 | from networkx.convert_matrix import from_numpy_matrix
18 | import multiprocessing
19 | from torch_sparse import coalesce
20 | from torch_geometric.utils import remove_self_loops
21 | from functools import partial
22 | import deepdish as dd
23 | from imports.gdc import GDC
24 |
25 |
26 | def split(data, batch):
27 | node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0)
28 | node_slice = torch.cat([torch.tensor([0]), node_slice])
29 |
30 | row, _ = data.edge_index
31 | edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0)
32 | edge_slice = torch.cat([torch.tensor([0]), edge_slice])
33 |
34 | # Edge indices should start at zero for every graph.
35 | data.edge_index -= node_slice[batch[row]].unsqueeze(0)
36 |
37 | slices = {'edge_index': edge_slice}
38 | if data.x is not None:
39 | slices['x'] = node_slice
40 | if data.edge_attr is not None:
41 | slices['edge_attr'] = edge_slice
42 | if data.y is not None:
43 | if data.y.size(0) == batch.size(0):
44 | slices['y'] = node_slice
45 | else:
46 | slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long)
47 | if data.pos is not None:
48 | slices['pos'] = node_slice
49 |
50 | return data, slices
51 |
52 |
53 | def cat(seq):
54 | seq = [item for item in seq if item is not None]
55 | seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq]
56 | return torch.cat(seq, dim=-1).squeeze() if len(seq) > 0 else None
57 |
58 | class NoDaemonProcess(multiprocessing.Process):
59 | @property
60 | def daemon(self):
61 | return False
62 |
63 | @daemon.setter
64 | def daemon(self, value):
65 | pass
66 |
67 |
68 | class NoDaemonContext(type(multiprocessing.get_context())):
69 | Process = NoDaemonProcess
70 |
71 |
72 | def read_data(data_dir):
73 | onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))]
74 | onlyfiles.sort()
75 | batch = []
76 | pseudo = []
77 | y_list = []
78 | edge_att_list, edge_index_list,att_list = [], [], []
79 |
80 |     # parallel computing
81 | cores = multiprocessing.cpu_count()
82 | pool = multiprocessing.Pool(processes=cores)
83 | #pool = MyPool(processes = cores)
84 | func = partial(read_sigle_data, data_dir)
85 |
86 | import timeit
87 |
88 | start = timeit.default_timer()
89 |
90 | res = pool.map(func, onlyfiles)
91 |
92 | pool.close()
93 | pool.join()
94 |
95 | stop = timeit.default_timer()
96 |
97 | print('Time: ', stop - start)
98 |
99 |
100 |
101 | for j in range(len(res)):
102 | edge_att_list.append(res[j][0])
103 | edge_index_list.append(res[j][1]+j*res[j][4])
104 | att_list.append(res[j][2])
105 | y_list.append(res[j][3])
106 | batch.append([j]*res[j][4])
107 | pseudo.append(np.diag(np.ones(res[j][4])))
108 |
109 | edge_att_arr = np.concatenate(edge_att_list)
110 | edge_index_arr = np.concatenate(edge_index_list, axis=1)
111 | att_arr = np.concatenate(att_list, axis=0)
112 | pseudo_arr = np.concatenate(pseudo, axis=0)
113 | y_arr = np.stack(y_list)
114 | edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float()
115 | att_torch = torch.from_numpy(att_arr).float()
116 | y_torch = torch.from_numpy(y_arr).long() # classification
117 | batch_torch = torch.from_numpy(np.hstack(batch)).long()
118 | edge_index_torch = torch.from_numpy(edge_index_arr).long()
119 | pseudo_torch = torch.from_numpy(pseudo_arr).float()
120 | data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos = pseudo_torch )
121 |
122 |
123 | data, slices = split(data, batch_torch)
124 |
125 | return data, slices
126 |
127 |
128 | def read_sigle_data(data_dir, filename, use_gdc=False):
129 |
130 | temp = dd.io.load(osp.join(data_dir, filename))
131 |
132 | # read edge and edge attribute
133 | pcorr = np.abs(temp['pcorr'][()])
134 |
135 | num_nodes = pcorr.shape[0]
136 | G = from_numpy_matrix(pcorr)
137 | A = nx.to_scipy_sparse_matrix(G)
138 | adj = A.tocoo()
139 | edge_att = np.zeros(len(adj.row))
140 | for i in range(len(adj.row)):
141 | edge_att[i] = pcorr[adj.row[i], adj.col[i]]
142 |
143 | edge_index = np.stack([adj.row, adj.col])
144 | edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att))
145 | edge_index = edge_index.long()
146 | edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes,
147 | num_nodes)
148 | att = temp['corr'][()]
149 | label = temp['label'][()]
150 |
151 | att_torch = torch.from_numpy(att).float()
152 | y_torch = torch.from_numpy(np.array(label)).long() # classification
153 |
154 | data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att)
155 |
156 | if use_gdc:
157 | '''
158 | Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html
159 | '''
160 | data.edge_attr = data.edge_attr.squeeze()
161 | gdc = GDC(self_loop_weight=1, normalization_in='sym',
162 | normalization_out='col',
163 | diffusion_kwargs=dict(method='ppr', alpha=0.2),
164 | sparsification_kwargs=dict(method='topk', k=20,
165 | dim=0), exact=True)
166 | data = gdc(data)
167 | return data.edge_attr.data.numpy(),data.edge_index.data.numpy(),data.x.data.numpy(),data.y.data.item(),num_nodes
168 |
169 | else:
170 | return edge_att.data.numpy(),edge_index.data.numpy(),att,label,num_nodes
171 |
172 | if __name__ == "__main__":
173 | data_dir = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw'
174 | filename = '50346.h5'
175 | read_sigle_data(data_dir, filename)
176 |
177 |
178 |
179 |
180 |
181 |
182 |
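
A sketch (not part of this file) of what read_sigle_data yields for one subject, mirroring the __main__ block above. The directory and filename are placeholders, and the .h5 file is assumed to contain 'corr', 'pcorr' and 'label' entries as the loader expects; the import assumes the repository root is on PYTHONPATH.

from imports.read_abide_stats_parall import read_sigle_data

edge_att, edge_index, node_feat, label, num_nodes = read_sigle_data(
    '/path/to/ABIDE_pcp/cpac/filt_noglobal/raw', '50346.h5')

print(edge_index.shape)  # (2, num_edges): COO edges of the |pcorr| graph, self-loops removed
print(node_feat.shape)   # (num_nodes, num_nodes): each ROI's correlation profile as its feature vector
print(num_nodes, label)
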
--------------------------------------------------------------------------------
/imports/utils.py:
--------------------------------------------------------------------------------
1 | from scipy import stats
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | from scipy.io import loadmat
6 | from sklearn.model_selection import StratifiedKFold
7 | from sklearn.model_selection import KFold
8 |
9 |
10 | def train_val_test_split(kfold = 5, fold = 0):
11 | n_sub = 1035
12 | id = list(range(n_sub))
13 |
14 |
15 | import random
16 | random.seed(123)
17 | random.shuffle(id)
18 |
19 | kf = KFold(n_splits=kfold, random_state=123,shuffle = True)
20 | kf2 = KFold(n_splits=kfold-1, shuffle=True, random_state = 666)
21 |
22 |
23 | test_index = list()
24 | train_index = list()
25 | val_index = list()
26 |
27 |     for tr, te in kf.split(np.array(id)):
28 |         test_index.append(te)
29 |         tr_idx, val_idx = list(kf2.split(tr))[0]   # positions within tr
30 |         train_index.append(tr[tr_idx])             # map back to subject indices
31 |         val_index.append(tr[val_idx])
32 |
33 | train_id = train_index[fold]
34 | test_id = test_index[fold]
35 | val_id = val_index[fold]
36 |
37 | return train_id,val_id,test_id
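
A sketch (not part of this file) of how the fold indices might be consumed downstream; the dataset slicing in the comment is hypothetical, only the split call comes from this module, and the import assumes the repository root is on PYTHONPATH.

from imports.utils import train_val_test_split

tr_id, val_id, te_id = train_val_test_split(kfold=5, fold=0)
# positions into the 1035-subject list; train/val/test are mutually disjoint
assert not set(te_id) & set(tr_id) and not set(te_id) & set(val_id)
# e.g. (hypothetical): train_set = dataset[tr_id.tolist()]
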
--------------------------------------------------------------------------------
/net/braingnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch_geometric.nn import TopKPooling
5 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
6 | from torch_geometric.utils import (add_self_loops, sort_edge_index,
7 | remove_self_loops)
8 | from torch_sparse import spspmm
9 |
10 | from net.braingraphconv import MyNNConv
11 |
12 |
13 | ##########################################################################################################################
14 | class Network(torch.nn.Module):
15 | def __init__(self, indim, ratio, nclass, k=8, R=200):
16 | '''
17 |
18 | :param indim: (int) node feature dimension
19 | :param ratio: (float) pooling ratio in (0,1)
20 | :param nclass: (int) number of classes
21 | :param k: (int) number of communities
22 | :param R: (int) number of ROIs
23 | '''
24 | super(Network, self).__init__()
25 |
26 | self.indim = indim
27 | self.dim1 = 32
28 | self.dim2 = 32
29 | self.dim3 = 512
30 | self.dim4 = 256
31 | self.dim5 = 8
32 | self.k = k
33 | self.R = R
34 |
35 | self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim1 * self.indim))
36 | self.conv1 = MyNNConv(self.indim, self.dim1, self.n1, normalize=False)
37 | self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
38 | self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim2 * self.dim1))
39 | self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False)
40 | self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
41 |
42 | #self.fc1 = torch.nn.Linear((self.dim2) * 2, self.dim2)
43 | self.fc1 = torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2)
44 | self.bn1 = torch.nn.BatchNorm1d(self.dim2)
45 | self.fc2 = torch.nn.Linear(self.dim2, self.dim3)
46 | self.bn2 = torch.nn.BatchNorm1d(self.dim3)
47 | self.fc3 = torch.nn.Linear(self.dim3, nclass)
48 |
49 |
50 |
51 |
52 | def forward(self, x, edge_index, batch, edge_attr, pos):
53 |
54 | x = self.conv1(x, edge_index, edge_attr, pos)
55 | x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch)
56 |
57 | pos = pos[perm]
58 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
59 |
60 | edge_attr = edge_attr.squeeze()
61 | edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))
62 |
63 | x = self.conv2(x, edge_index, edge_attr, pos)
64 | x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index,edge_attr, batch)
65 |
66 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
67 |
68 | x = torch.cat([x1,x2], dim=1)
69 | x = self.bn1(F.relu(self.fc1(x)))
70 | x = F.dropout(x, p=0.5, training=self.training)
71 | x = self.bn2(F.relu(self.fc2(x)))
72 | x= F.dropout(x, p=0.5, training=self.training)
73 | x = F.log_softmax(self.fc3(x), dim=-1)
74 |
75 | return x,self.pool1.weight,self.pool2.weight, torch.sigmoid(score1).view(x.size(0),-1), torch.sigmoid(score2).view(x.size(0),-1)
76 |
77 | def augment_adj(self, edge_index, edge_weight, num_nodes):
78 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
79 | num_nodes=num_nodes)
80 | edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
81 | num_nodes)
82 | edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
83 | edge_weight, num_nodes, num_nodes,
84 | num_nodes)
85 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
86 | return edge_index, edge_weight
87 |
88 |
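
A minimal forward-pass sketch (not part of this file) for Network on two synthetic fully-connected graphs, assuming the repository root is on PYTHONPATH. Shapes follow the constructor docstring (R ROIs, node feature dimension = indim); feeding pos as a one-hot ROI identity matrix is an assumption that matches the R-dimensional input expected by self.n1/self.n2.

import torch
from torch_geometric.data import Data, Batch
from net.braingnn import Network

R = 200
model = Network(indim=R, ratio=0.5, nclass=2, k=8, R=R)

def toy_graph():
    # fully connected toy graph with random positive edge weights
    ei = torch.combinations(torch.arange(R), r=2).t()
    ei = torch.cat([ei, ei.flip(0)], dim=1)
    return Data(x=torch.randn(R, R), edge_index=ei,
                edge_attr=torch.rand(ei.size(1), 1),
                pos=torch.eye(R), y=torch.tensor([0]))

batch = Batch.from_data_list([toy_graph(), toy_graph()])
out, w1, w2, s1, s2 = model(batch.x, batch.edge_index, batch.batch,
                            batch.edge_attr, batch.pos)
print(out.shape)  # (2, nclass) log-probabilities; s1/s2 are per-graph node scores from the two pools
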
--------------------------------------------------------------------------------
/net/braingraphconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Parameter
4 | from net.brainmsgpassing import MyMessagePassing
5 | from torch_geometric.utils import add_remaining_self_loops,softmax
6 |
7 | from torch_geometric.typing import (OptTensor)
8 |
9 | from net.inits import uniform
10 |
11 |
12 | class MyNNConv(MyMessagePassing):
13 | def __init__(self, in_channels, out_channels, nn, normalize=False, bias=True,
14 | **kwargs):
15 | super(MyNNConv, self).__init__(aggr='mean', **kwargs)
16 |
17 | self.in_channels = in_channels
18 | self.out_channels = out_channels
19 | self.normalize = normalize
20 | self.nn = nn
21 | #self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
22 |
23 | if bias:
24 | self.bias = Parameter(torch.Tensor(out_channels))
25 | else:
26 | self.register_parameter('bias', None)
27 |
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | # uniform(self.in_channels, self.weight)
32 | uniform(self.in_channels, self.bias)
33 |
34 | def forward(self, x, edge_index, edge_weight=None, pseudo= None, size=None):
35 | """"""
36 | edge_weight = edge_weight.squeeze()
37 | if size is None and torch.is_tensor(x):
38 | edge_index, edge_weight = add_remaining_self_loops(
39 | edge_index, edge_weight, 1, x.size(0))
40 |
41 | weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
42 | if torch.is_tensor(x):
43 | x = torch.matmul(x.unsqueeze(1), weight).squeeze(1)
44 | else:
45 | x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
46 | None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))
47 |
48 | # weight = self.nn(pseudo).view(-1, self.out_channels,self.in_channels)
49 | # if torch.is_tensor(x):
50 | # x = torch.matmul(x.unsqueeze(1), weight.permute(0,2,1)).squeeze(1)
51 | # else:
52 | # x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
53 | # None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))
54 |
55 | return self.propagate(edge_index, size=size, x=x,
56 | edge_weight=edge_weight)
57 |
58 | def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTensor):
59 | edge_weight = softmax(edge_weight, edge_index_i, ptr, size_i)
60 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
61 |
62 | def update(self, aggr_out):
63 | if self.bias is not None:
64 | aggr_out = aggr_out + self.bias
65 | if self.normalize:
66 | aggr_out = F.normalize(aggr_out, p=2, dim=-1)
67 | return aggr_out
68 |
69 | def __repr__(self):
70 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
71 | self.out_channels)
72 |
73 |
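
A short sketch (not part of this file) of MyNNConv on a random graph, assuming the repository root is on PYTHONPATH. The mapper network mirrors how braingnn.py builds n1 (pseudo -> k communities -> a per-node in x out weight matrix); the one-hot pseudo coordinates and random edges are illustration-only assumptions.

import torch
import torch.nn as nn
from net.braingraphconv import MyNNConv

in_dim, out_dim, k, R = 200, 32, 8, 200
mapper = nn.Sequential(nn.Linear(R, k, bias=False), nn.ReLU(),
                       nn.Linear(k, in_dim * out_dim))
conv = MyNNConv(in_dim, out_dim, mapper, normalize=False)

x = torch.randn(R, in_dim)                   # R ROIs, each with an in_dim feature profile
edge_index = torch.randint(0, R, (2, 1000))  # random edges, for illustration only
edge_weight = torch.rand(1000)
pseudo = torch.eye(R)                        # one-hot ROI identity
out = conv(x, edge_index, edge_weight, pseudo)
print(out.shape)                             # (R, out_dim)
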
--------------------------------------------------------------------------------
/net/brainmsgpassing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 |
4 | import torch
5 | # from torch_geometric.utils import scatter_
6 | from torch_scatter import scatter,scatter_add
7 |
8 | special_args = [
9 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
10 | ]
11 | __size_error_msg__ = ('All tensors which should get mapped to the same source '
12 | 'or target nodes must be of same size in dimension 0.')
13 |
14 | is_python2 = sys.version_info[0] < 3
15 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec
16 |
17 |
18 | class MyMessagePassing(torch.nn.Module):
19 | r"""Base class for creating message passing layers
20 | .. math::
21 | \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
22 | \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
23 | \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
24 | where :math:`\square` denotes a differentiable, permutation invariant
25 | function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
26 | and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
27 | MLPs.
28 |     See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
29 |     create_gnn.html>`__ for the accompanying tutorial.
30 | Args:
31 | aggr (string, optional): The aggregation scheme to use
32 | (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
33 | (default: :obj:`"add"`)
34 | flow (string, optional): The flow direction of message passing
35 | (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
36 | (default: :obj:`"source_to_target"`)
37 | node_dim (int, optional): The axis along which to propagate.
38 | (default: :obj:`0`)
39 | """
40 | def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
41 | super(MyMessagePassing, self).__init__()
42 |
43 | self.aggr = aggr
44 | assert self.aggr in ['add', 'mean', 'max']
45 |
46 | self.flow = flow
47 | assert self.flow in ['source_to_target', 'target_to_source']
48 |
49 | self.node_dim = node_dim
50 | assert self.node_dim >= 0
51 |
52 | self.__message_args__ = getargspec(self.message)[0][1:]
53 | self.__special_args__ = [(i, arg)
54 | for i, arg in enumerate(self.__message_args__)
55 | if arg in special_args]
56 | self.__message_args__ = [
57 | arg for arg in self.__message_args__ if arg not in special_args
58 | ]
59 | self.__update_args__ = getargspec(self.update)[0][2:]
60 |
61 | def propagate(self, edge_index, size=None, **kwargs):
62 | r"""The initial call to start propagating messages.
63 | Args:
64 | edge_index (Tensor): The indices of a general (sparse) assignment
65 | matrix with shape :obj:`[N, M]` (can be directed or
66 | undirected).
67 | size (list or tuple, optional): The size :obj:`[N, M]` of the
68 |                 assignment matrix. If set to :obj:`None`, the size will be
69 |                 automatically inferred and assumed to be symmetric.
70 | (default: :obj:`None`)
71 | **kwargs: Any additional data which is needed to construct messages
72 | and to update node embeddings.
73 | """
74 |
75 | dim = self.node_dim
76 | size = [None, None] if size is None else list(size)
77 | assert len(size) == 2
78 |
79 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
80 | ij = {"_i": i, "_j": j}
81 |
82 | message_args = []
83 | for arg in self.__message_args__:
84 | if arg[-2:] in ij.keys():
85 | tmp = kwargs.get(arg[:-2], None)
86 | if tmp is None: # pragma: no cover
87 | message_args.append(tmp)
88 | else:
89 | idx = ij[arg[-2:]]
90 | if isinstance(tmp, tuple) or isinstance(tmp, list):
91 | assert len(tmp) == 2
92 | if tmp[1 - idx] is not None:
93 | if size[1 - idx] is None:
94 | size[1 - idx] = tmp[1 - idx].size(dim)
95 | if size[1 - idx] != tmp[1 - idx].size(dim):
96 | raise ValueError(__size_error_msg__)
97 | tmp = tmp[idx]
98 |
99 | if tmp is None:
100 | message_args.append(tmp)
101 | else:
102 | if size[idx] is None:
103 | size[idx] = tmp.size(dim)
104 | if size[idx] != tmp.size(dim):
105 | raise ValueError(__size_error_msg__)
106 |
107 | tmp = torch.index_select(tmp, dim, edge_index[idx])
108 | message_args.append(tmp)
109 | else:
110 | message_args.append(kwargs.get(arg, None))
111 |
112 | size[0] = size[1] if size[0] is None else size[0]
113 | size[1] = size[0] if size[1] is None else size[1]
114 |
115 | kwargs['edge_index'] = edge_index
116 | kwargs['size'] = size
117 |
118 | for (idx, arg) in self.__special_args__:
119 | if arg[-2:] in ij.keys():
120 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
121 | else:
122 | message_args.insert(idx, kwargs[arg])
123 |
124 | update_args = [kwargs[arg] for arg in self.__update_args__]
125 |
126 | out = self.message(*message_args)
127 | # out = scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i])
128 |         out = scatter_add(out, edge_index[i], dim, dim_size=size[i])  # aggregation fixed to sum regardless of self.aggr; with softmax-normalized edge weights (as in MyNNConv) this acts as a weighted mean
129 | out = self.update(out, *update_args)
130 |
131 | return out
132 |
133 | def message(self, x_j): # pragma: no cover
134 | r"""Constructs messages to node :math:`i` in analogy to
135 | :math:`\phi_{\mathbf{\Theta}}` for each edge in
136 | :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
137 | :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
138 | Can take any argument which was initially passed to :meth:`propagate`.
139 | In addition, tensors passed to :meth:`propagate` can be mapped to the
140 | respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
141 |         :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
142 | """
143 |
144 | return x_j
145 |
146 | def update(self, aggr_out): # pragma: no cover
147 | r"""Updates node embeddings in analogy to
148 | :math:`\gamma_{\mathbf{\Theta}}` for each node
149 | :math:`i \in \mathcal{V}`.
150 | Takes in the output of aggregation as first argument and any argument
151 | which was initially passed to :meth:`propagate`."""
152 |
153 | return aggr_out
154 |
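
A toy subclass sketch (not part of this file, and not used anywhere in the repository) showing how the base class resolves message/update arguments by introspection; the import assumes the repository root is on PYTHONPATH.

import torch
from net.brainmsgpassing import MyMessagePassing

class SumConv(MyMessagePassing):
    def __init__(self):
        super(SumConv, self).__init__(aggr='add')

    def message(self, x_j):      # x_j: source-node features gathered per edge
        return x_j

    def update(self, aggr_out):  # aggr_out: summed messages per target node
        return aggr_out

conv = SumConv()
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])   # a 4-node ring
print(conv.propagate(edge_index, size=(4, 4), x=x).shape)  # (4, 3)
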
--------------------------------------------------------------------------------
/net/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | bound = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 | if tensor is not None:
12 | bound = math.sqrt(6 / ((1 + a**2) * fan))
13 | tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 | if tensor is not None:
18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
19 | tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(1)
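
A small sketch (not part of this file) of how these helpers are applied to raw tensors; the fan-in value of 32 mirrors MyNNConv.reset_parameters, which calls uniform(in_channels, bias). The import assumes the repository root is on PYTHONPATH.

import torch
from net.inits import uniform, glorot, zeros

w = torch.empty(32, 64)
glorot(w)        # uniform in [-sqrt(6/(32+64)), +sqrt(6/(32+64))]

b = torch.empty(64)
uniform(32, b)   # uniform in [-1/sqrt(32), +1/sqrt(32)]
zeros(b)         # reset to zero
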
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work
2 | anaconda-client==1.7.2
3 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work
4 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist
5 | appdirs==1.4.4
6 | argh==0.26.2
7 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work
8 | arrow==0.13.1
9 | ase==3.21.1
10 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
11 | astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work
12 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work
13 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
14 | atomicwrites==1.4.0
15 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
16 | autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work
17 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work
18 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
19 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work
20 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
21 | binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work
22 | bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work
23 | bkcharts==0.2
24 | black==19.10b0
25 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
26 | bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work
27 | boto==2.49.0
28 | Bottleneck==1.3.2
29 | brotlipy==0.7.0
30 | certifi==2020.12.5
31 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work
32 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work
33 | click @ file:///home/linux1/recipes/ci/click_1610990599742/work
34 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
35 | clyent==1.2.2
36 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
37 | contextlib2==0.6.0.post1
38 | cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work
39 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work
40 | cycler==0.10.0
41 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work
42 | cytoolz==0.11.0
43 | dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work
44 | decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work
45 | deepdish==0.3.6
46 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
47 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
48 | distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work
49 | docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work
50 | entrypoints==0.3
51 | et-xmlfile==1.0.1
52 | fastcache==1.1.0
53 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work
54 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work
55 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work
56 | fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work
57 | future==0.18.2
58 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work
59 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
60 | gmpy2==2.0.8
61 | googledrivedownloader==0.4
62 | greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work
63 | h5py==2.10.0
64 | HeapDict==1.0.1
65 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
66 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work
67 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
68 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work
69 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work
70 | inflection==0.5.1
71 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work
72 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
73 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
74 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work
75 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
76 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
77 | isodate==0.6.0
78 | isort @ file:///tmp/build/80754af9/isort_1616355431277/work
79 | itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work
80 | jdcal==1.4.1
81 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work
82 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work
83 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
84 | jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work
85 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work
86 | json5==0.9.5
87 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
88 | jupyter==1.0.0
89 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
90 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
91 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work
92 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work
93 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work
94 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work
95 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
96 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work
97 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
98 | keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work
99 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
100 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work
101 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
102 | llvmlite==0.36.0
103 | locket==0.2.1
104 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work
105 | MarkupSafe==1.1.1
106 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work
107 | mccabe==0.6.1
108 | mistune==0.8.4
109 | mkl-fft==1.3.0
110 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work
111 | mkl-service==2.3.0
112 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work
113 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work
114 | mpmath==1.2.1
115 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work
116 | multipledispatch==0.6.0
117 | mypy-extensions==0.4.3
118 | nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work
119 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
120 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
121 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
122 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
123 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
124 | nibabel==3.2.1
125 | nilearn==0.7.1
126 | nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work
127 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work
128 | notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work
129 | numba @ file:///tmp/build/80754af9/numba_1616774046117/work
130 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
131 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work
132 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
133 | olefile==0.46
134 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work
135 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
136 | pandas==1.2.4
137 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
138 | parso==0.7.0
139 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work
140 | path @ file:///tmp/build/80754af9/path_1614022220526/work
141 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work
142 | pathspec==0.7.0
143 | pathtools==0.1.2
144 | patsy==0.5.1
145 | pep8==1.7.1
146 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
147 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
148 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work
149 | pkginfo==1.7.0
150 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work
151 | ply==3.11
152 | poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work
153 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1618088486455/work
154 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
155 | protobuf==3.17.0
156 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work
157 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
158 | py @ file:///tmp/build/80754af9/py_1607971587848/work
159 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work
160 | pycosat==0.6.3
161 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
162 | pycurl==7.43.0.6
163 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work
164 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work
165 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work
166 | Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
167 | pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work
168 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work
169 | pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work
170 | pyodbc===4.0.0-unsupported
171 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work
172 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
173 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
174 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
175 | pytest==6.2.3
176 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
177 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
178 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work
179 | python-louvain==0.15
180 | python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work
181 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
182 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
183 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
184 | PyYAML==5.4.1
185 | pyzmq==20.0.0
186 | QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work
187 | qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl
188 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work
189 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work
190 | QtPy==1.9.0
191 | rdflib==5.0.0
192 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work
193 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work
194 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work
195 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work
196 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work
197 | scikit-image==0.16.2
198 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work
199 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work
200 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work
201 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work
202 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
203 | simplegeneric==0.8.1
204 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work
205 | sip==4.19.13
206 | six @ file:///tmp/build/80754af9/six_1605205327372/work
207 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work
208 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work
209 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work
210 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work
211 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
212 | Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work
213 | sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work
214 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work
215 | sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work
216 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work
217 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work
218 | sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work
219 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
220 | spyder @ file:///tmp/build/80754af9/spyder_1618327905127/work
221 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1617396566288/work
222 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work
223 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work
224 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work
225 | tables==3.6.1
226 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
227 | tensorboardX==2.2
228 | terminado==0.9.4
229 | testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
230 | text-unidecode==1.3
231 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work
232 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
233 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work
234 | tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work
235 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work
236 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work
237 | torch==1.7.0
238 | torch-cluster==1.5.9
239 | torch-geometric==1.7.0
240 | torch-scatter==2.0.6
241 | torch-sparse==0.6.9
242 | torch-spline-conv==1.2.1
243 | torchaudio==0.7.0a0+ac17b64
244 | torchvision==0.8.0
245 | tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
246 | tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work
247 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
248 | tsBNgen==1.0.0
249 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work
250 | typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
251 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work
252 | unicodecsv==0.14.1
253 | Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work
254 | urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work
255 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work
256 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
257 | webencodings==0.5.1
258 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work
259 | whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work
260 | widgetsnbextension==3.5.1
261 | wrapt==1.12.1
262 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work
263 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work
264 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work
265 | xlwt==1.3.0
266 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
267 | zict==2.0.0
268 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
269 | zope.event==4.5.0
270 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work
271 |
--------------------------------------------------------------------------------