├── examples.png
├── pipeline.png
├── __pycache__
│   ├── gGAN.cpython-36.pyc
│   └── gGAN.cpython-37.pyc
├── README.md
├── demo.py
└── gGAN.py
/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/examples.png
--------------------------------------------------------------------------------
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/pipeline.png
--------------------------------------------------------------------------------
/__pycache__/gGAN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/__pycache__/gGAN.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/gGAN.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/__pycache__/gGAN.cpython-37.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gGAN-PY: graph-based Generative Adversarial Network for normalizing brain graphs with respect to a fixed template, in Python
2 | gGAN-PY is the Python implementation of a graph-based Generative Adversarial Network (gGAN) framework for normalizing brain graphs with respect to a fixed template,
3 | written by Zeynep Gürler and Ahmed Nebli. Please contact zeynepgurler1998@gmail.com for inquiries. Thanks.
4 |
5 | > **Foreseeing Brain Graph Evolution Over Time
6 | Using Deep Adversarial Network Normalizer**
7 | > [Zeynep Gürler](https://github.com/zeynepgurler)1, [Ahmed Nebli](https://github.com/ahmednebli)1,2, [Islem Rekik](https://basira-lab.com/)1
8 | > 1BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey
9 | > 2National School for Computer Science (ENSI), Mannouba, Tunisia
10 | >
11 | > **Abstract:** *Foreseeing the brain
12 | evolution as a complex highly interconnected system, widely modeled as a graph,
13 | is crucial for mapping dynamic interactions between different anatomical regions
14 | of interest (ROIs) in health and disease. Interestingly, brain graph evolution
15 | models remain almost absent in the literature. Here we design an adversarial brain
16 | network normalizer for representing each brain network as a transformation of a
17 | fixed centered population-driven connectional template. Such graph normalization
18 | with respect to a fixed reference paves the way for reliably identifying the most
19 | similar training samples (i.e., brain graphs) to the testing sample at baseline
20 | timepoint. The testing evolution trajectory will be then spanned by the selected
21 | training graphs and their corresponding evolution trajectories. We base our prediction
22 | framework on geometric deep learning which naturally operates on graphs and nicely preserves
23 | their topological properties. Specifically, we propose the first graph-based
24 | Generative Adversarial Network (gGAN) that not only learns how to normalize brain
25 | graphs with respect to a fixed connectional brain template (CBT) (i.e., a brain
26 | template that selectively captures the most common features across a brain population)
27 | but also learns a high-order representation of the brain graphs, also called embeddings. We use these embeddings to compute the similarity between training and testing
28 | subjects, which allows us to pick the closest training subjects at baseline timepoint to predict the evolution of the testing brain graph over time. A series of benchmarks against several comparison methods showed that our proposed method achieved the
29 | lowest brain disease evolution prediction error using a single baseline timepoint.
30 |
31 |
32 | # Detailed proposed framework pipeline
33 | This work has been published in the proceedings of the PRIME-MICCAI 2020 workshop. Our framework predicts brain graph evolution trajectories using a gGAN architecture comprising a normalizer network with respect to a fixed connectional brain template (CBT). Our learning-based framework comprises three key steps: (1) learning to normalize brain graphs with respect to the CBT, (2) embedding the training graphs, the testing graphs, and the CBT, and (3) predicting brain graph evolution using top k-closest neighbor selection (a sketch of this prediction step is given below). Experimental results against comparison methods demonstrate that our framework achieves the best results in terms of average mean absolute error (MAE). We evaluated our proposed framework on the preprocessed OASIS-2 dataset (https://www.oasis-brains.org/).
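
A minimal sketch of prediction step (3), distilled from test_gGAN in demo.py; the function name predict_followup and the variable names are hypothetical, and the shapes assume one embedding row per subject:

```python
import numpy as np

# train_emb: (n_train, m), test_emb: (n_test, m), cbt_emb: (1, m) generator embeddings
# train_followup: (n_train, f) follow-up feature vectors of the training subjects
def predict_followup(train_emb, test_emb, cbt_emb, train_followup, k=5):
    train_res = np.abs(train_emb - cbt_emb)               # residuals w.r.t. the CBT embedding
    test_res = np.abs(test_emb - cbt_emb)
    sim = test_res @ train_res.T                          # (n_test, n_train) dot-product similarities
    neighbors = np.argsort(-np.abs(sim), axis=1)[:, :k]   # k most similar training subjects
    return train_followup[neighbors].mean(axis=1)         # (n_test, f) averaged follow-up graphs
```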
34 |
35 | More details can be found in our paper and in our research paper video on the BASIRA Lab YouTube channel (see the arXiv and YouTube links below).
36 |
37 | 
38 |
39 |
40 | # Libraries to preinstall in Python
41 | * [Python 3.8](https://www.python.org/)
42 | * [PyTorch 1.4.0](http://pytorch.org/)
43 | * [Torch-geometric](https://github.com/rusty1s/pytorch_geometric)
44 | * [Torch-sparse](https://github.com/rusty1s/pytorch_sparse)
45 | * [Torch-scatter](https://github.com/rusty1s/pytorch_scatter)
46 | * [Scikit-learn 0.23.0+](https://scikit-learn.org/stable/)
47 | * [Matplotlib 3.1.3+](https://matplotlib.org/)
48 | * [Numpy 1.18.1+](https://numpy.org/)
49 |
50 | # Demo
51 |
52 | gGAN is coded in Python 3.8 on Windows 10. GPU is not needed to run the code.
53 | This code has been slightly modified to be compatible across all PyTorch versions.
54 | demo.py implements the brain graph evolution trajectory prediction framework proposed
55 | in the paper Foreseeing Brain Graph Evolution Over Time Using Deep Adversarial Network
56 | Normalizer. To use only the brain graph normalizer (gGAN), you can run gGAN.py.
57 | In this repo, we release the gGAN source code trained and tested on simulated
58 | data as described below:
59 |
60 | **Data preparation**
61 |
62 | We simulate a random graph dataset drawn from two Gaussian distributions using np.random.normal.
63 | The number of subjects, number of regions, number of epochs, and number of folds are entered
64 | manually by the user when starting the demo. A sketch of this simulation is given below.
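
The snippet below mirrors how demo.py simulates the two timepoints; the Gaussian means (0.6 and 0.4) and standard deviation (0.3) are the values used in the released demo, while the subject and region counts are example values:

```python
import numpy as np

nbr_of_sub, nbr_of_regions = 90, 35   # example values entered by the user

# baseline graphs drawn from the first Gaussian distribution
data = np.abs(np.random.normal(0.6, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions)))
# follow-up graphs drawn from the second Gaussian distribution
data_next = np.abs(np.random.normal(0.4, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions)))
```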
65 |
66 | To train and evaluate the gGAN code on other datasets, you need to provide the following (a loading sketch is given below):
67 |
68 | • A tensor of size (n × m × m) stacking the symmetric matrices of the training subjects.
69 | n denotes the total number of subjects and m denotes the number of regions.
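
As a minimal, hypothetical example of preparing such an input (the file name my_dataset.npy is a placeholder), the training tensor could be loaded and sanity-checked like this:

```python
import numpy as np

data = np.load('my_dataset.npy')   # assumed to hold an (n x m x m) array of connectivity matrices

n, m, m2 = data.shape
assert m == m2, "each brain graph must be a square m x m matrix"
assert np.allclose(data, data.transpose(0, 2, 1)), "brain graphs must be symmetric"
print(n, "subjects with", m, "regions each")
```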
70 |
71 | The demo outputs are:
72 |
73 | • A tensor of size (l × t × m × m) stacking the predicted brain graphs of the testing subjects.
74 | t denotes the total number of testing subjects and l denotes the number of different values of k used for neighbor selection (k = 2, ..., 10 in the demo).
75 |
76 | **Train and test gGAN**
77 |
78 | To evaluate our framework, we used a leave-one-out cross-validation strategy. The released demo uses k-fold cross-validation with a user-specified number of folds; setting the number of folds equal to the number of subjects corresponds to leave-one-out.
79 |
80 |
81 | # Python Code
82 | To run gGAN, first generate a fixed connectional brain template (CBT) using netNorm (https://github.com/basiralab/netNorm-PY); an implementation of netNorm is also included in gGAN.py and is used in the sketch below.
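
A minimal sketch of generating the CBT and training the normalizer, following the calls made in demo.py; the subject counts and hyper-parameters are example values (note that, as with demo.py, importing from gGAN.py also triggers its interactive prompts, since that module runs its own demo at import time):

```python
import numpy as np
from gGAN import gGAN, netNorm

nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1 = 35, 500, 3, 100

# an independent set of subjects is used only to build the CBT
independent_data = np.abs(np.random.normal(0.6, 0.3, (30, nbr_of_regions, nbr_of_regions)))
data = np.abs(np.random.normal(0.6, 0.3, (90, nbr_of_regions, nbr_of_regions)))

CBT = netNorm(independent_data, independent_data.shape[0], nbr_of_regions)  # (m x m) template
gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT)  # saves generator weights per fold
```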
83 |
84 | # Example Results
85 | If you set the number of epochs to 500, the number of subjects to 90, and the number of regions to 35, you will get approximately the following outputs when running the demo with the default parameter setting:
86 |
87 | 
88 |
89 |
90 | # YouTube videos to install and run the code and understand how gGAN works
91 |
92 | To install and run our prediction framework, check the following YouTube video:
93 | https://youtu.be/2zKle7GzrIM
94 |
95 | To learn about how our architecture works, check the following YouTube video:
96 | https://youtu.be/5vpQIFzf2Go
97 |
98 | # Related References
99 | Fast Graph Representation Learning with PyTorch Geometric: Fey, M., Lenssen, J. E., 2019, ICLR Workshop on Representation Learning on Graphs and Manifolds.
100 |
101 | Network Normalization for Integrating Multi-view Networks (netNorm): Dhifallah, S., Rekik, I., 2020, Estimation of connectional brain templates using selective multi-view network normalization.
102 |
103 | # arXiv link
104 |
105 | You can download our paper at: https://arxiv.org/abs/2009.11166
106 |
107 | # Please Cite the Following paper when using gGAN:
108 |
109 | @article{gurler2020, title={ Foreseeing Brain Graph Evolution Over Time
110 | Using Deep Adversarial Network Normalizer},
111 | author={Gurler, Zeynep and Nebli, Ahmed and Rekik, Islem},
112 | journal={Predictive Intelligence in Medicine International Society and Conference Series on Medical Image Computing and Computer-Assisted Intervention},
113 | volume={},
114 | pages={},
115 | year={2020},
116 | publisher={Springer}
117 | }
118 |
119 |
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pdb
4 | import numpy as np
5 | import math
6 | import itertools
7 | import torch
8 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout
9 | from sklearn.preprocessing import MinMaxScaler
10 | from sklearn import preprocessing
11 | from torch_geometric.data import Data
12 | from torch.autograd import Variable
13 | import torch.nn.functional as F
14 | import torch.nn as nn
15 | from torch_geometric.nn import NNConv
16 | from torch_geometric.nn import BatchNorm, EdgePooling, TopKPooling, global_add_pool
17 | from sklearn.model_selection import KFold
18 | from sklearn.cluster import KMeans
19 | import matplotlib.pyplot as plt
20 | import scipy.io
21 | import scipy.stats as stats
22 | import pandas as pd
23 | import seaborn as sns
24 | import random
25 | from gGAN import gGAN, netNorm
26 |
27 | torch.cuda.empty_cache()
28 | torch.cuda.empty_cache()
29 |
30 | # random seed
31 | manualSeed = 1
32 |
33 | np.random.seed(manualSeed)
34 | random.seed(manualSeed)
35 | torch.manual_seed(manualSeed)
36 |
37 | if torch.cuda.is_available():
38 | device = torch.device('cuda')
39 | print('running on GPU')
40 | # if you are using GPU
41 | torch.cuda.manual_seed(manualSeed)
42 | torch.cuda.manual_seed_all(manualSeed)
43 |
44 | torch.backends.cudnn.enabled = False
45 | torch.backends.cudnn.benchmark = False
46 | torch.backends.cudnn.deterministic = True
47 |
48 | else:
49 | device = torch.device("cpu")
50 | print('running on CPU')
51 |
52 |
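# demo(): end-to-end demo of the gGAN-based evolution prediction framework. It
# (1) simulates baseline and follow-up brain graphs, (2) builds a CBT with netNorm and
# trains the gGAN normalizer, (3) reloads the trained generator of each fold to embed the
# training subjects, the testing subjects and the CBT, and (4) predicts the follow-up graphs
# of the testing subjects by averaging their k closest training subjects (k = 2..10) in the
# embedding residual space, reporting the MAE for every k.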
53 | def demo():
54 | def cast_data(array_of_tensors, version):
55 | version1 = torch.tensor(version, dtype=torch.int)
56 |
57 | N_ROI = array_of_tensors[0].shape[0]
58 | CHANNELS = 1
59 | dataset = []
60 | edge_index = torch.zeros(2, N_ROI * N_ROI)
61 | edge_attr = torch.zeros(N_ROI * N_ROI, CHANNELS)
62 | x = torch.zeros((N_ROI, N_ROI)) # 35 x 35
63 | y = torch.zeros((1,))
64 |
65 | counter = 0
66 | for i in range(N_ROI):
67 | for j in range(N_ROI):
68 | edge_index[:, counter] = torch.tensor([i, j])
69 | counter += 1
70 | for mat in array_of_tensors: # 1,35,35,4
71 |
72 | if version1 == 0:
73 | edge_attr = mat.view(nbr_of_regions * nbr_of_regions, 1)
74 | x = mat.view(nbr_of_regions, nbr_of_regions)
75 | edge_index = torch.tensor(edge_index, dtype=torch.long)
76 | edge_attr = torch.tensor(edge_attr, dtype=torch.float)
77 | x = torch.tensor(x, dtype=torch.float)
78 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
79 | dataset.append(data)
80 |
81 | elif version1 == 1:
82 | edge_attr = torch.randn(N_ROI * N_ROI, CHANNELS)
83 | x = torch.randn(N_ROI, N_ROI) # 35 x 35
84 | edge_index = torch.tensor(edge_index, dtype=torch.long)
85 | edge_attr = torch.tensor(edge_attr, dtype=torch.float)
86 | x = torch.tensor(x, dtype=torch.float)
87 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
88 | dataset.append(data)
89 |
90 | return dataset
91 |
92 | #####################################################################################################
93 |
94 | def linear_features(data):
95 | n_roi = data[0].shape[0]
96 | n_sub = data.shape[0]
97 | counter = 0
98 |
99 | num_feat = (n_roi * (n_roi - 1) // 2)
100 | final_data = np.empty([n_sub, num_feat], dtype=float)
101 | for k in range(n_sub):
102 | for i in range(n_roi):
103 | for j in range(i+1, n_roi):
104 | final_data[k, counter] = data[k, i, j]
105 | counter += 1
106 | counter = 0
107 |
108 | return final_data
109 |
110 | def make_sym_matrix(nbr_of_regions, feature_vector):
111 | sym_matrix = np.zeros([9, feature_vector.shape[1], nbr_of_regions, nbr_of_regions], dtype=np.double)
112 | for j in range(9):
113 | for i in range(feature_vector.shape[1]):
114 | my_matrix = np.zeros([nbr_of_regions, nbr_of_regions], dtype=np.double)
115 |
116 | my_matrix[np.triu_indices(nbr_of_regions, k=1)] = feature_vector[j, i, :]
117 | my_matrix = my_matrix + my_matrix.T
118 | my_matrix[np.diag_indices(nbr_of_regions)] = 0
119 | sym_matrix[j, i,:,:] = my_matrix
120 |
121 | return sym_matrix
122 |
123 | def plot_predictions(predicted, fold):
124 | plt.clf()
125 | for j in range(predicted.shape[0]):
126 | for i in range(predicted.shape[1]):
127 | predicted_sub = predicted[j, i, :, :]
128 | plt.pcolor(abs(predicted_sub))
129 | if(j == 0 and i == 0):
130 | plt.colorbar()
131 | plt.imshow(predicted_sub)
132 | plt.savefig('./plot/img' + str(fold) + str(j) + str(i) + '.png')
133 |
134 | def plot_MAE(prediction, data_next, test, fold):
135 | # mae
136 | MAE = np.zeros((9), dtype=np.double)
137 | for i in range(9):
138 | MAE_i = abs(prediction[i, :, :] - data_next[test])
139 | MAE[i] = np.mean(MAE_i)
140 |
141 | plt.clf()
142 | k = ['k=2', 'k=3', 'k=4', 'k=5', 'k=6', 'k=7', 'k=8', 'k=9', 'k=10']
143 | sns.set(style="whitegrid")
144 |
145 | df = pd.DataFrame(dict(x=k, y=MAE))
146 | # total = sns.load_dataset('tips')
147 | ax = sns.barplot(x="x", y="y", data=df)
148 | min = MAE.min() - 0.01
149 | max = MAE.max() + 0.01
150 | ax.set(ylim=(min, max))
151 | plt.savefig('./plot/mae' + str(fold) + '.png')
152 |
153 | ######################################################################################################################################
154 |
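# Generator: the same first two NNConv + BatchNorm layers as the generator trained in gGAN.py,
# but forward() stops after the second layer and returns the resulting (nbr_of_regions x 1)
# node embedding as a NumPy array; it is used here only to embed subjects and the CBT.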
155 | class Generator(nn.Module):
156 | def __init__(self):
157 | super(Generator, self).__init__()
158 |
159 | nn = Sequential(Linear(1, (nbr_of_regions * nbr_of_regions)), ReLU())
160 | self.conv1 = NNConv(nbr_of_regions, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True)
161 | self.conv11 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
162 |
163 | nn = Sequential(Linear(1, nbr_of_regions), ReLU())
164 | self.conv2 = NNConv(nbr_of_regions, 1, nn, aggr='mean', root_weight=True, bias=True)
165 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
166 |
167 | nn = Sequential(Linear(1, nbr_of_regions), ReLU())
168 | self.conv3 = NNConv(1, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True)
169 | self.conv33 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
170 |
171 |
172 |
173 | def forward(self, data):
174 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
175 |
176 | x1 = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
177 | x1 = F.dropout(x1, training=self.training)
178 |
179 | x2 = F.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
180 | x2 = F.dropout(x2, training=self.training)
181 |
182 | embedded = x2.detach().cpu().clone().numpy()
183 |
184 | return embedded
185 |
186 | def embed(Casted_source):
187 | embedded_data = np.zeros((1, 35), dtype=float)
188 | i = 0
189 | for data_A in Casted_source: ## take a subject from source and target data
190 | embedded = generator(data_A) # 35 x35
191 |
192 | if i == 0:
193 | embedded = np.transpose(embedded)
194 | embedded_data = embedded
195 | else:
196 | embedded = np.transpose(embedded)
197 | embedded_data = np.append(embedded_data, embedded, axis=0)
198 | i = i + 1
199 | return embedded_data
200 |
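# test_gGAN: predicts the follow-up graphs of the testing subjects. Each subject is
# represented by the absolute difference (residual) between its embedding and the CBT
# embedding, similarity is the dot product between residual embeddings, and for every
# k in 2..10 the k most similar training subjects are selected and their follow-up
# feature vectors are averaged to form the prediction.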
201 | def test_gGAN(data_next, embedded_train_data, embedded_test_data, embedded_CBT):
202 | def x_to_x(x_train, x_test, nbr_of_trn, nbr_of_tst):
203 | result = np.empty((nbr_of_tst, nbr_of_trn), dtype=float)
204 | for i in range(nbr_of_tst):
205 | x_t = np.transpose(x_test[i])
206 | for j in range(nbr_of_trn):
207 | result[i, j] = np.matmul(x_train[j], x_t)
208 | return result
209 |
210 | def check(neighbors, i, j):
211 | for val in neighbors[i, :]:
212 | if val == j:
213 | return 1
214 | return 0
215 |
216 | def k_neighbors(x_to_x, k_num, nbr_of_trn, nbr_of_tst):
217 | neighbors = np.zeros((nbr_of_tst, k_num), dtype=int)
218 | used = np.zeros((nbr_of_tst, nbr_of_trn), dtype=int)
219 | current = 0
220 | for i in range(nbr_of_tst):
221 | for k in range(k_num):
222 | for j in range(nbr_of_trn):
223 | if abs(x_to_x[i, j]) > current:
224 | if check(neighbors, i, j) == 0:
225 | neighbors[i, k] = j
226 | current = abs(x_to_x[i, neighbors[i, k]])
227 | current = 0
228 |
229 | return neighbors
230 |
231 | def subtract_cbt(x, cbt, length):
232 | for i in range(length):
233 | x[i] = abs(x[i] - cbt[0])
234 |
235 | return x
236 |
237 | def predict_samples(k_neighbors, t1, nbr_of_tst):
238 | average = np.zeros((nbr_of_tst, nbr_of_feat), dtype=float)
239 | for i in range(nbr_of_tst):
240 | for j in range(len(k_neighbors[0])):
241 | average[i] = average[i] + t1[k_neighbors[i,j],:]
242 |
243 | average[i] = average[i] / len(k_neighbors[0])
244 |
245 | return average
246 |
247 | residual_of_tr_embeddings = subtract_cbt(embedded_train_data, embedded_CBT, len(embedded_train_data))
248 | residual_of_ts_embeddings = subtract_cbt(embedded_test_data, embedded_CBT, len(embedded_test_data))
249 |
250 | dot_of_residuals = x_to_x(residual_of_tr_embeddings, residual_of_ts_embeddings, len(train), len(test))
251 | for k in range(2, 11):
252 | k_neighbors_ = k_neighbors(dot_of_residuals, k, len(train), len(test))
253 |
254 | if k == 2:
255 | prediction = predict_samples(k_neighbors_, data_next, len(embedded_test_data))
256 | prediction = np.reshape(prediction, (1, len(embedded_test_data), nbr_of_feat))
257 | else:
258 | new_predict = predict_samples(k_neighbors_, data_next, len(embedded_test_data))
259 | new_predict = np.reshape(new_predict, (1, len(embedded_test_data), nbr_of_feat))
260 | prediction = np.append(prediction, new_predict, axis=0)
261 |
262 | return prediction
263 |
264 | nbr_of_sub = int(input('Please select the number of subjects: '))
265 | while nbr_of_sub < 5:
266 | print("The number of subjects cannot be less than 5.")
267 | nbr_of_sub = int(input('Please select the number of subjects: '))
268 | nbr_of_sub_for_cbt = int(input('Please select the number of subjects to generate the CBT: '))
269 | nbr_of_regions = int(input('Please select the number of regions: '))
270 | nbr_of_epochs = int(input('Please select the number of epochs: '))
271 | nbr_of_folds = int(input('Please select the number of folds: '))
272 | hyper_param1 = 100
273 | nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2)
274 |
275 | data = np.random.normal(0.6, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions))
276 | data = np.abs(data)
277 | independent_data = np.random.normal(0.6, 0.3, (nbr_of_sub_for_cbt, nbr_of_regions, nbr_of_regions))
278 | independent_data = np.abs(independent_data)
279 | data_next = np.random.normal(0.4, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions))
280 | data_next = np.abs(data_next)
281 | CBT = netNorm(independent_data, nbr_of_sub_for_cbt, nbr_of_regions)
282 | gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT)
283 |
284 | # embed train and test subjects
285 | kfold = KFold(n_splits=nbr_of_folds, shuffle=True, random_state=manualSeed)
286 |
287 | source_data = torch.from_numpy(data) # convert numpy array to torch tensor
288 | source_data = source_data.type(torch.FloatTensor)
289 |
290 | target_data = np.reshape(CBT, (1, nbr_of_regions, nbr_of_regions, 1))
291 | target_data = torch.from_numpy(target_data) # convert numpy array to torch tensor
292 | target_data = target_data.type(torch.FloatTensor)
293 |
294 | i = 1
295 | for train, test in kfold.split(source_data):
296 | adversarial_loss = torch.nn.BCELoss()
297 | l1_loss = torch.nn.L1Loss()
298 | trained_model_gen = torch.load('./weight_' + str(i) + 'generator_.model')
299 | generator = Generator()
300 | generator.load_state_dict(trained_model_gen)
301 |
302 | train_data = source_data[train]
303 | test_data = source_data[test]
304 |
305 | generator.to(device)
306 | adversarial_loss.to(device)
307 | l1_loss.to(device)
308 |
309 | X_train_casted_source = [d.to(device) for d in cast_data(train_data, 0)]
310 | X_test_casted_source = [d.to(device) for d in cast_data(test_data, 0)]
311 | data_B = [d.to(device) for d in cast_data(target_data, 0)]
312 |
313 | embedded_train_data = embed(X_train_casted_source)
314 | embedded_test_data = embed(X_test_casted_source)
315 | embedded_CBT = embed(data_B)
316 |
317 | if i == 1:
318 | data_next = linear_features(data_next)
319 | predicted_flat = test_gGAN(data_next, embedded_train_data, embedded_test_data, embedded_CBT)
320 |
321 | plot_MAE(predicted_flat, data_next, test, i)
322 | i = i + 1
323 |
324 | predicted = make_sym_matrix(nbr_of_regions, predicted_flat)
325 | plot_predictions(predicted, i - 1)
326 |
327 | demo()
328 |
329 |
--------------------------------------------------------------------------------
/gGAN.py:
--------------------------------------------------------------------------------
1 | """Main function of gGAN for the paper: Foreseeing Brain Graph Evolution Over Time
2 | Using Deep Adversarial Network Normalizer
3 | Details can be found in:
4 | (1) the original paper: https://arxiv.org/abs/2009.11166
5 | ---------------------------------------------------------------------
6 | This file contains the implementation of two key steps of our gGAN framework:
7 | netNorm(v, nbr_of_sub, nbr_of_regions)
8 | Inputs:
9 | v: (n x t x t) tensor stacking the source graphs of all subjects
10 | n the total number of subjects
11 | t the number of regions
12 | Output:
13 | CBT: (t x t) matrix representing the connectional brain template
14 |
15 | gGAN(sourceGraph, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT)
16 | Inputs:
17 | sourceGraph: (n x t x t) tensor stacking the source graphs of all subjects
18 | n the total number of subjects
19 | t the number of regions
20 | CBT: (t x t) matrix representing the connectional brain template generated by netNorm
21 |
22 | Output:
23 | translatedGraph: (t x t) matrix representing a source graph translated (normalized) with respect to the CBT
24 |
25 | This code has been slightly modified to be compatible across all PyTorch versions.
26 |
27 | (2) Dependencies: please install the following libraries:
28 | - matplotlib
29 | - numpy
30 | - scikit-learn
31 | - pytorch
32 | - pytorch-geometric
33 | - pytorch-scatter
34 | - pytorch-sparse
35 | - scipy
36 |
37 | ---------------------------------------------------------------------
38 | Copyright 2020 ().
39 | Please cite the above paper if you use this code.
40 | All rights reserved.
41 | """
42 |
43 |
44 | # If you are using Google Colab please uncomment the three following lines.
45 | # !pip install torch_geometric
46 | # !pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
47 | # !pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
48 |
49 |
50 | import argparse
51 | import os
52 | import pdb
53 | import numpy as np
54 | import math
55 | import itertools
56 | import torch
57 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout
58 | from sklearn.preprocessing import MinMaxScaler
59 | from sklearn import preprocessing
60 | from torch_geometric.data import Data
61 | from torch.autograd import Variable
62 | import torch.nn.functional as F
63 | import torch.nn as nn
64 | from torch_geometric.nn import NNConv
65 | from torch_geometric.nn import BatchNorm, EdgePooling, TopKPooling, global_add_pool
66 | from sklearn.model_selection import KFold
67 | from sklearn.cluster import KMeans
68 | import matplotlib.pyplot as plt
69 | import scipy.io
70 | import scipy.stats as stats
71 | import random
72 |
73 | import seaborn as sns
74 |
75 | torch.cuda.empty_cache()
76 | torch.cuda.empty_cache()
77 |
78 | # random seed
79 | manualSeed = 1
80 |
81 | np.random.seed(manualSeed)
82 | random.seed(manualSeed)
83 | torch.manual_seed(manualSeed)
84 |
85 | if torch.cuda.is_available():
86 | device = torch.device('cuda')
87 | print('running on GPU')
88 | # if you are using GPU
89 | torch.cuda.manual_seed(manualSeed)
90 | torch.cuda.manual_seed_all(manualSeed)
91 |
92 | torch.backends.cudnn.enabled = False
93 | torch.backends.cudnn.benchmark = False
94 | torch.backends.cudnn.deterministic = True
95 |
96 | else:
97 | device = torch.device("cpu")
98 | print('running on CPU')
99 |
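# netNorm: builds the connectional brain template (CBT). For every pairwise connectivity
# feature, it computes each subject's cumulative Euclidean distance to all other subjects,
# keeps the feature value of the subject with the smallest distance, and folds the selected
# features back into a symmetric (t x t) matrix with a zero diagonal.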
100 | def netNorm(v, nbr_of_sub, nbr_of_regions):
101 | nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2)
102 |
103 | def upper_triangular():
104 | All_subj = np.zeros((nbr_of_sub, nbr_of_feat))
105 | for j in range(nbr_of_sub):
106 | subj_x = v[j, :, :]
107 | subj_x = np.reshape(subj_x, (nbr_of_regions, nbr_of_regions))
108 | subj_x = subj_x[np.triu_indices(nbr_of_regions, k=1)]
109 | subj_x = np.reshape(subj_x, (1, nbr_of_feat))
110 | All_subj[j, :] = subj_x
111 |
112 | return All_subj
113 |
114 | def distances_inter(All_subj):
115 | theta = 0
116 | distance_vector = np.zeros(1)
117 | distance_vector_final = np.zeros(1)
118 | x = All_subj
119 | for i in range(nbr_of_feat):
120 | ROI_i = x[:, i]
121 | for j in range(nbr_of_sub):
122 | subj_j = ROI_i[j:j+1]
123 |
124 | distance_euclidienne_sub_j_sub_k = 0
125 | for k in range(nbr_of_sub):
126 | if k != j:
127 | subj_k = ROI_i[k:k+1]
128 |
129 | distance_euclidienne_sub_j_sub_k = distance_euclidienne_sub_j_sub_k + np.square(subj_k - subj_j)
130 | theta +=1
131 | if j == 0:
132 | distance_vector = np.sqrt(distance_euclidienne_sub_j_sub_k)
133 | else:
134 | distance_vector = np.concatenate((distance_vector, np.sqrt(distance_euclidienne_sub_j_sub_k)), axis=0)
135 |
136 | distance_vector = np.reshape(distance_vector, (nbr_of_sub, 1))
137 | if i == 0:
138 | distance_vector_final = distance_vector
139 | else:
140 | distance_vector_final = np.concatenate((distance_vector_final, distance_vector), axis=1)
141 |
142 | print(theta)
143 | return distance_vector_final
144 |
145 |
146 | def minimum_distances(distance_vector_final):
147 | x = distance_vector_final
148 |
149 | for i in range(nbr_of_feat):
150 | minimum_sub = x[0, i:i+1]
151 | minimum_sub = float(minimum_sub)
152 | general_minimum = 0
153 | general_minimum = np.array(general_minimum)
154 | for k in range(1, nbr_of_sub):
155 | local_sub = x[k:k+1, i:i+1]
156 | local_sub = float(local_sub)
157 | if local_sub < minimum_sub:
158 | general_minimum = k
159 | general_minimum = np.array(general_minimum)
160 | minimum_sub = local_sub
161 | if i == 0:
162 | final_general_minimum = np.array(general_minimum)
163 | else:
164 | final_general_minimum = np.vstack((final_general_minimum, general_minimum))
165 |
166 | final_general_minimum = np.transpose(final_general_minimum)
167 |
168 | return final_general_minimum
169 |
170 | def new_tensor(final_general_minimum, All_subj):
171 | y = All_subj
172 | x = final_general_minimum
173 | for i in range(nbr_of_feat):
174 | optimal_subj = x[:, i:i+1]
175 | optimal_subj = np.reshape(optimal_subj, (1))
176 | optimal_subj = int(optimal_subj)
177 | if i == 0:
178 | final_new_tensor = y[optimal_subj: optimal_subj+1, i:i+1]
179 | else:
180 | final_new_tensor = np.concatenate((final_new_tensor, y[optimal_subj: optimal_subj+1, i:i+1]), axis=1)
181 |
182 | return final_new_tensor
183 |
184 | def make_sym_matrix(nbr_of_regions, feature_vector):
185 | my_matrix = np.zeros([nbr_of_regions, nbr_of_regions], dtype=np.double)
186 |
187 | my_matrix[np.triu_indices(nbr_of_regions, k=1)] = feature_vector
188 | my_matrix = my_matrix + my_matrix.T
189 | my_matrix[np.diag_indices(nbr_of_regions)] = 0
190 |
191 | return my_matrix
192 |
193 | def re_make_tensor(final_new_tensor, nbr_of_regions):
194 | x = final_new_tensor
195 | #x = np.reshape(x, (nbr_of_views, nbr_of_feat))
196 |
197 | x = make_sym_matrix(nbr_of_regions, x)
198 | x = np.reshape(x, (1, nbr_of_regions, nbr_of_regions))
199 |
200 | return x
201 |
202 | Upp_trig = upper_triangular()
203 | Dis_int = distances_inter(Upp_trig)
204 | Min_dis = minimum_distances(Dis_int)
205 | New_ten = new_tensor(Min_dis, Upp_trig)
206 | Re_ten = re_make_tensor(New_ten, nbr_of_regions)
207 | Re_ten = np.reshape(Re_ten, (nbr_of_regions, nbr_of_regions))
208 | np.fill_diagonal(Re_ten, 0)
209 | network = np.array(Re_ten)
210 | return network
211 |
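# gGAN: trains the adversarial normalizer. For each cross-validation fold it trains the
# graph generator against the graph discriminator with a BCE adversarial loss plus an L1
# loss to the CBT weighted by hyper_param1, using AdamW optimizers, then saves the
# generator and discriminator weights of that fold to disk and plots the loss curves.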
212 | def gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT):
213 | def cast_data(array_of_tensors, version):
214 | version1 = torch.tensor(version, dtype=torch.int)
215 |
216 | N_ROI = array_of_tensors[0].shape[0]
217 | CHANNELS = 1
218 | dataset = []
219 | edge_index = torch.zeros(2, N_ROI * N_ROI)
220 | edge_attr = torch.zeros(N_ROI * N_ROI, CHANNELS)
221 | x = torch.zeros((N_ROI, N_ROI)) # 35 x 35
222 | y = torch.zeros((1,))
223 |
224 | counter = 0
225 | for i in range(N_ROI):
226 | for j in range(N_ROI):
227 | edge_index[:, counter] = torch.tensor([i, j])
228 | counter += 1
229 | for mat in array_of_tensors: #1,35,35,4
230 |
231 | if version1 == 0:
232 | edge_attr = mat.view((nbr_of_regions*nbr_of_regions), 1)
233 | x = mat.view(nbr_of_regions, nbr_of_regions)
234 | edge_index = torch.tensor(edge_index, dtype=torch.long)
235 | edge_attr = torch.tensor(edge_attr, dtype=torch.float)
236 | x = torch.tensor(x, dtype=torch.float)
237 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
238 | dataset.append(data)
239 |
240 | elif version1 == 1:
241 | edge_attr = torch.randn(N_ROI * N_ROI, CHANNELS)
242 | x = torch.randn(N_ROI, N_ROI) # 35 x 35
243 | edge_index = torch.tensor(edge_index, dtype=torch.long)
244 | edge_attr = torch.tensor(edge_attr, dtype=torch.float)
245 | x = torch.tensor(x, dtype=torch.float)
246 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
247 | dataset.append(data)
248 |
249 | return dataset
250 |
251 | # ------------------------------------------------------------
252 |
253 | def plotting_loss(losses_generator, losses_discriminator, epoch):
254 | plt.figure(1)
255 | plt.plot(epoch, losses_generator, 'r-')
256 | plt.plot(epoch, losses_discriminator, 'b-')
257 | plt.legend(['G Loss', 'D Loss'])
258 | plt.xlabel('Epoch')
259 | plt.ylabel('Loss')
260 | plt.savefig('./plot/loss' + str(epoch) + '.png')
261 |
262 | # -------------------------------------------------------------
263 |
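# Generator: graph normalizer built from three edge-conditioned NNConv layers, each followed
# by batch normalization and a sigmoid; the third layer's output is concatenated with the
# first layer's output and the two halves are averaged to produce the normalized brain graph.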
264 | class Generator(nn.Module):
265 | def __init__(self):
266 | super(Generator, self).__init__()
267 |
268 | nn = Sequential(Linear(1, (nbr_of_regions*nbr_of_regions)), ReLU())
269 | self.conv1 = NNConv(nbr_of_regions, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True)
270 | self.conv11 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
271 |
272 | nn = Sequential(Linear(1, nbr_of_regions), ReLU())
273 | self.conv2 = NNConv(nbr_of_regions, 1, nn, aggr='mean', root_weight=True, bias=True)
274 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
275 |
276 | nn = Sequential(Linear(1, nbr_of_regions), ReLU())
277 | self.conv3 = NNConv(1, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True)
278 | self.conv33 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
279 |
280 | def forward(self, data):
281 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
282 |
283 | x1 = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
284 | x1 = F.dropout(x1, training=self.training)
285 |
286 | x2 = F.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
287 | x2 = F.dropout(x2, training=self.training)
288 |
289 | x3 = torch.cat([F.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1)
290 | x4 = x3[:, 0:nbr_of_regions]
291 | x5 = x3[:, nbr_of_regions:2*nbr_of_regions]
292 |
293 | x6 = (x4 + x5) / 2
294 | return x6
295 |
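# Discriminator1: two NNConv + BatchNorm layers whose edge attributes are the concatenation
# of the candidate graph's edge weights and the source graph's edge weights; the final
# sigmoid outputs a per-node realness score in [0, 1].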
296 | class Discriminator1(torch.nn.Module):
297 | def __init__(self):
298 | super(Discriminator1, self).__init__()
299 | nn = Sequential(Linear(2, (nbr_of_regions*nbr_of_regions)), ReLU())
300 | self.conv1 = NNConv(nbr_of_regions, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True)
301 | self.conv11 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
302 |
303 | nn = Sequential(Linear(2, nbr_of_regions), ReLU())
304 | self.conv2 = NNConv(nbr_of_regions, 1, nn, aggr='mean', root_weight=True, bias=True)
305 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
306 |
307 |
308 | def forward(self, data, data_to_translate):
309 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
310 | edge_attr_data_to_translate = data_to_translate.edge_attr
311 |
312 | edge_attr_data_to_translate_reshaped = edge_attr_data_to_translate.view(nbr_of_regions*nbr_of_regions, 1)
313 |
314 | gen_input = torch.cat((edge_attr, edge_attr_data_to_translate_reshaped), -1)
315 | x = F.relu(self.conv11(self.conv1(x, edge_index, gen_input)))
316 | x = F.dropout(x, training=self.training)
317 | x = F.relu(self.conv22(self.conv2(x, edge_index, gen_input)))
318 |
319 | return F.sigmoid(x)
320 |
321 | # ----------------------------------------
322 | # Training
323 | # ----------------------------------------
324 |
325 | n_fold_counter = 1
326 | plot_loss_g = np.empty((nbr_of_epochs), dtype=float)
327 | plot_loss_d = np.empty((nbr_of_epochs), dtype=float)
328 |
329 | kfold = KFold(n_splits=nbr_of_folds, shuffle=True, random_state=manualSeed)
330 |
331 | source_data = torch.from_numpy(data) # convert numpy array to torch tensor
332 | source_data = source_data.type(torch.FloatTensor)
333 |
334 | target_data = np.reshape(CBT, (1, nbr_of_regions, nbr_of_regions, 1))
335 | target_data = torch.from_numpy(target_data) # convert numpy array to torch tensor
336 | target_data = target_data.type(torch.FloatTensor)
337 |
338 | for train, test in kfold.split(source_data):
339 | # Loss function
340 | adversarial_loss = torch.nn.BCELoss()
341 | l1_loss = torch.nn.L1Loss()
342 | # Initialize generator and discriminator
343 | generator = Generator()
344 | discriminator1 = Discriminator1()
345 |
346 | generator.to(device)
347 | discriminator1.to(device)
348 | adversarial_loss.to(device)
349 | l1_loss.to(device)
350 |
351 | # Optimizers
352 | optimizer_G = torch.optim.AdamW(generator.parameters(), lr=0.005, betas=(0.5, 0.999))
353 | optimizer_D = torch.optim.AdamW(discriminator1.parameters(), lr=0.01, betas=(0.5, 0.999))
354 |
355 | # ------------------------------- select source data and target data -------------------------------
356 |
357 | train_source, test_source = source_data[train], source_data[test] ## from a specific source view
358 |
359 | # 1: everything random; 0: everything is the matrix in question
360 |
361 | train_casted_source = [d.to(device) for d in cast_data(train_source, 0)]
362 | train_casted_target = [d.to(device) for d in cast_data(target_data, 0)]
363 |
364 | for epoch in range(nbr_of_epochs):
365 | # Train Generator
366 | with torch.autograd.set_detect_anomaly(True):
367 |
368 | losses_generator = []
369 | losses_discriminator = []
370 |
371 | for data_A in train_casted_source:
372 | generators_output_ = generator(data_A) # 35 x35
373 | generators_output = generators_output_.view(1, nbr_of_regions, nbr_of_regions, 1).type(torch.FloatTensor)
374 |
375 | generators_output_casted = [d.to(device) for d in cast_data(generators_output, 0)]
376 | for (data_discriminator) in generators_output_casted:
377 | discriminator_output_of_gen = discriminator1(data_discriminator, data_A).to(device)
378 |
379 | g_loss_adversarial = adversarial_loss(discriminator_output_of_gen, torch.ones_like(discriminator_output_of_gen))
380 |
381 | g_loss_pix2pix = l1_loss(generators_output_, train_casted_target[0].edge_attr.view(nbr_of_regions, nbr_of_regions))
382 |
383 | g_loss = g_loss_adversarial + (hyper_param1 * g_loss_pix2pix)
384 | losses_generator.append(g_loss)
385 |
386 | discriminator_output_for_real_loss = discriminator1(data_A, train_casted_target[0])
387 |
388 | real_loss = adversarial_loss(discriminator_output_for_real_loss,
389 | (torch.ones_like(discriminator_output_for_real_loss, requires_grad=False)))
390 | fake_loss = adversarial_loss(discriminator_output_of_gen.detach(), torch.zeros_like(discriminator_output_of_gen))
391 |
392 | d_loss = (real_loss + fake_loss) / 2
393 | losses_discriminator.append(d_loss)
394 |
395 | optimizer_G.zero_grad()
396 | losses_generator = torch.mean(torch.stack(losses_generator))
397 | losses_generator.backward(retain_graph=True)
398 | optimizer_G.step()
399 |
400 | optimizer_D.zero_grad()
401 | losses_discriminator = torch.mean(torch.stack(losses_discriminator))
402 |
403 | losses_discriminator.backward(retain_graph=True)
404 | optimizer_D.step()
405 |
406 | print(
407 | "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
408 | % (epoch, nbr_of_epochs, losses_discriminator, losses_generator))
409 |
410 | plot_loss_g[epoch] = losses_generator.detach().cpu().clone().numpy()
411 | plot_loss_d[epoch] = losses_discriminator.detach().cpu().clone().numpy()
412 |
413 | torch.save(generator.state_dict(), "./weight_" + str(n_fold_counter) + "generator" + "_" + ".model")
414 | torch.save(discriminator1.state_dict(), "./weight_" + str(n_fold_counter) + "discriminator" + "_" + ".model")
415 |
416 | interval = range(0, nbr_of_epochs)
417 | plotting_loss(plot_loss_g, plot_loss_d, interval)
418 | n_fold_counter += 1
419 | torch.cuda.empty_cache()
420 | torch.cuda.empty_cache()
421 |
422 |
423 | nbr_of_sub = int(input('Please select the number of subjects: '))
424 | while nbr_of_sub < 5:
425 | print("The number of subjects cannot be less than 5.")
426 | nbr_of_sub = int(input('Please select the number of subjects: '))
427 | nbr_of_sub_for_cbt = int(input('Please select the number of subjects to generate the CBT: '))
428 | nbr_of_regions = int(input('Please select the number of regions: '))
429 | nbr_of_epochs = int(input('Please select the number of epochs: '))
430 | nbr_of_folds = int(input('Please select the number of folds: '))
431 | hyper_param1 = 100
432 | nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2)
433 |
434 | data = np.random.normal(0.6, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions))
435 | data = np.abs(data)
436 | independent_data = np.random.normal(0.6, 0.3, (nbr_of_sub_for_cbt, nbr_of_regions, nbr_of_regions))
437 | independent_data = np.abs(independent_data)
438 | CBT = netNorm(independent_data, nbr_of_sub_for_cbt, nbr_of_regions)
439 | gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT)
440 |
--------------------------------------------------------------------------------