├── examples.png
├── pipeline.png
├── __pycache__
│   ├── gGAN.cpython-36.pyc
│   └── gGAN.cpython-37.pyc
├── README.md
├── demo.py
└── gGAN.py

/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/examples.png
--------------------------------------------------------------------------------

/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/pipeline.png
--------------------------------------------------------------------------------

/__pycache__/gGAN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/__pycache__/gGAN.cpython-36.pyc
--------------------------------------------------------------------------------

/__pycache__/gGAN.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/gGAN/HEAD/__pycache__/gGAN.cpython-37.pyc
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# gGAN-PY (graph-based Generative Adversarial Network for normalizing brain graphs with respect to a fixed template) in Python
gGAN-PY (graph-based Generative Adversarial Network) framework for normalizing brain graphs with respect to a fixed template, coded up in Python by Zeynep Gürler and Ahmed Nebli. Please contact zeynepgurler1998@gmail.com for inquiries. Thanks.

> **Foreseeing Brain Graph Evolution Over Time Using Deep Adversarial Network Normalizer**<br/>
> [Zeynep Gürler](https://github.com/zeynepgurler)<sup>1</sup>, [Ahmed Nebli](https://github.com/ahmednebli)<sup>1,2</sup>, [Islem Rekik](https://basira-lab.com/)<sup>1</sup><br/>
> <sup>1</sup>BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey<br/>
> <sup>2</sup>National School for Computer Science (ENSI), Mannouba, Tunisia<br/>
>
> **Abstract:** *Foreseeing the evolution of the brain, a complex and highly interconnected system widely modeled as a graph, is crucial for mapping dynamic interactions between different anatomical regions of interest (ROIs) in health and disease. Interestingly, brain graph evolution models remain almost absent in the literature. Here we design an adversarial brain network normalizer for representing each brain network as a transformation of a fixed centered population-driven connectional template. Such graph normalization with respect to a fixed reference paves the way for reliably identifying the most similar training samples (i.e., brain graphs) to the testing sample at baseline timepoint. The testing evolution trajectory will then be spanned by the selected training graphs and their corresponding evolution trajectories. We base our prediction framework on geometric deep learning, which naturally operates on graphs and nicely preserves their topological properties. Specifically, we propose the first graph-based Generative Adversarial Network (gGAN) that not only learns how to normalize brain graphs with respect to a fixed connectional brain template (CBT) (i.e., a brain template that selectively captures the most common features across a brain population) but also learns a high-order representation of the brain graphs, also called embeddings. We use these embeddings to compute the similarity between training and testing subjects, which allows us to pick the closest training subjects at baseline timepoint to predict the evolution of the testing brain graph over time. A series of benchmarks against several comparison methods showed that our proposed method achieved the lowest brain disease evolution prediction error using a single baseline timepoint.*


# Detailed proposed framework pipeline
This work has been published in the PRIME (Predictive Intelligence in Medicine) workshop at MICCAI 2020. Our framework is a brain graph evolution trajectory prediction framework based on a gGAN architecture comprising a normalizer network with respect to a fixed connectional brain template (CBT). Our learning-based framework comprises three key steps: (1) learning to normalize brain graphs with respect to the CBT, (2) embedding the training graphs, the testing graphs and the CBT, and (3) predicting brain graph evolution by selecting the top k closest training neighbors. Experimental results against comparison methods demonstrate that our framework achieves the best results in terms of average mean absolute error (MAE). We evaluated our proposed framework on the preprocessed OASIS-2 dataset (https://www.oasis-brains.org/).

More details can be found at: (link to the paper) and in our research paper video on the BASIRA Lab YouTube channel (link).
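Step (3) can be illustrated with a small NumPy sketch. This is a hypothetical, simplified version written for this README rather than the repository's code: it assumes the baseline embeddings are already computed (one row per subject), measures similarity as the dot product of the residuals |embedding − CBT embedding| (the same quantity demo.py uses), and predicts each testing subject's follow-up graph as the mean of the flattened follow-up graphs of its k most similar training subjects. `predict_follow_up` and `mean_absolute_error` are illustrative names, not functions from the repository.

```python
import numpy as np


def predict_follow_up(train_emb, test_emb, cbt_emb, train_next, k):
    """Hypothetical sketch of top-k neighbour prediction.

    train_emb:  (n_train, d) baseline embeddings of the training subjects
    test_emb:   (n_test, d) baseline embeddings of the testing subjects
    cbt_emb:    (d,) embedding of the connectional brain template
    train_next: (n_train, f) flattened follow-up graphs of the training subjects
    """
    # Residual of every embedding with respect to the CBT embedding
    train_res = np.abs(train_emb - cbt_emb)
    test_res = np.abs(test_emb - cbt_emb)
    # Similarity of each testing subject to each training subject: dot product of residuals
    similarity = test_res @ train_res.T  # (n_test, n_train)
    # k most similar training subjects per testing subject (ties keep the lower index)
    neighbours = np.argsort(-np.abs(similarity), axis=1, kind="stable")[:, :k]
    # Prediction = mean follow-up graph of the selected neighbours
    return train_next[neighbours].mean(axis=1)  # (n_test, f)


def mean_absolute_error(predicted, target):
    return float(np.mean(np.abs(predicted - target)))


# Toy usage with random numbers (not real embeddings or brain graphs)
rng = np.random.default_rng(0)
train_emb, test_emb, cbt_emb = rng.random((6, 4)), rng.random((2, 4)), rng.random(4)
train_next, test_next = rng.random((6, 10)), rng.random((2, 10))
for k in range(2, 5):
    pred = predict_follow_up(train_emb, test_emb, cbt_emb, train_next, k)
    print(k, pred.shape, mean_absolute_error(pred, test_next))
```

Sweeping k and comparing the MAE of each setting mirrors how the demo reports one error value per k.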

![gGAN pipeline](pipeline.png)


# Libraries to preinstall in Python
* [Python 3.8](https://www.python.org/)
* [PyTorch 1.4.0](http://pytorch.org/)
* [Torch-geometric](https://github.com/rusty1s/pytorch_geometric)
* [Torch-sparse](https://github.com/rusty1s/pytorch_sparse)
* [Torch-scatter](https://github.com/rusty1s/pytorch_scatter)
* [Scikit-learn 0.23.0+](https://scikit-learn.org/stable/)
* [Matplotlib 3.1.3+](https://matplotlib.org/)
* [Numpy 1.18.1+](https://numpy.org/)
* [Pandas](https://pandas.pydata.org/) and [Seaborn](https://seaborn.pydata.org/) (used by demo.py to plot the MAE)

# Demo

gGAN is coded in Python 3.8 on Windows 10. A GPU is not required to run the code.
This code has been slightly modified to be compatible across all PyTorch versions.
demo.py implements the brain graph evolution trajectory prediction framework proposed in the paper *Foreseeing Brain Graph Evolution Over Time Using Deep Adversarial Network Normalizer*. To use only the brain graph normalizer (gGAN), call the `gGAN` function defined in gGAN.py.
In this repo, we release the gGAN source code trained and tested on simulated data, as described below:

**Data preparation**

We simulated a random graph dataset drawn from two Gaussian distributions using np.random.normal: baseline graphs are drawn with mean 0.6 and follow-up graphs with mean 0.4 (standard deviation 0.3 in both cases, followed by an absolute value).
The number of subjects, the number of subjects used to generate the CBT, the number of regions, the number of epochs and the number of folds are entered by the user when the demo starts.

To train and evaluate gGAN on other datasets, you need to provide:

• A tensor of size (n × m × m) stacking the symmetric matrices of the training subjects, where n denotes the total number of subjects and m denotes the number of regions.
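The expected input format can be illustrated with a small NumPy sketch. It is hypothetical and only meant to show shapes: it simulates n subjects with m regions the way demo.py does (absolute values of Gaussian samples with mean 0.6 and standard deviation 0.3), then symmetrizes them to match the symmetric-matrix requirement above (an assumption, since demo.py does not symmetrize its simulated data). It shows the round trip between an m × m matrix and its m(m − 1)/2 upper-triangular features, the flattened format used for predictions, and builds a netNorm-style template by keeping, for every edge, the value of the subject whose value is closest (Euclidean distance) to all other subjects' values, mirroring the single-view `netNorm` function in gGAN.py. `to_features`, `to_matrix` and `cbt_by_selection` are illustrative names, not repository functions.

```python
import numpy as np


def to_features(graphs):
    """(n, m, m) symmetric matrices -> (n, m*(m-1)/2) upper-triangular features."""
    m = graphs.shape[1]
    rows, cols = np.triu_indices(m, k=1)  # row-major order, diagonal excluded
    return graphs[:, rows, cols]


def to_matrix(features, m):
    """Inverse of to_features for one subject: (m*(m-1)/2,) -> (m, m), zero diagonal."""
    matrix = np.zeros((m, m))
    matrix[np.triu_indices(m, k=1)] = features
    return matrix + matrix.T


def cbt_by_selection(graphs):
    """netNorm-style template (sketch): for every edge, keep the value of the subject
    whose value has the smallest Euclidean distance to all other subjects' values."""
    feats = to_features(graphs)  # (n, f)
    # dist[j, e] = sqrt(sum_k (feats[k, e] - feats[j, e]) ** 2)
    dist = np.sqrt(((feats[None, :, :] - feats[:, None, :]) ** 2).sum(axis=1))
    best = np.argmin(dist, axis=0)  # most representative subject for each edge
    return to_matrix(feats[best, np.arange(feats.shape[1])], graphs.shape[1])


rng = np.random.default_rng(1)
n, m = 8, 5
raw = np.abs(rng.normal(0.6, 0.3, (n, m, m)))
graphs = (raw + raw.transpose(0, 2, 1)) / 2  # symmetrize (assumption, see lead-in)
graphs[:, np.arange(m), np.arange(m)] = 0    # no self-connections

feats = to_features(graphs)
print(feats.shape)  # (8, 10) because 5 * 4 / 2 = 10
cbt = cbt_by_selection(graphs)
print(cbt.shape)  # (5, 5)
```

The broadcasted distance computation needs memory proportional to n² × m(m − 1)/2, which is fine for demo-sized data; a per-subject loop trades speed for memory on larger cohorts.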

The demo outputs are:

• A tensor of size (l × t × m × m) stacking the predicted brain graphs of the testing subjects, reconstructed from flattened predictions of size (l × t × m(m − 1)/2), where t denotes the number of testing subjects and l denotes the number of tested values of k (k = 2, ..., 10, so l = 9).

**Train and test gGAN**

To evaluate our framework, we used a leave-one-out cross-validation strategy. The demo uses KFold cross-validation with the number of folds chosen by the user.


# Python Code
To run gGAN, first generate a fixed connectional brain template (CBT), for example with netNorm: https://github.com/basiralab/netNorm-PY (gGAN.py also includes a `netNorm` function).

# Example Results
If you set the number of epochs to 500, the number of subjects to 90 and the number of regions to 35, running the demo with the default parameter settings should give approximately the following outputs:

![Example results](examples.png)


# YouTube videos to install and run the code and understand how gGAN works

To install and run our prediction framework, check the following YouTube video:
https://youtu.be/2zKle7GzrIM

To learn about how our architecture works, check the following YouTube video:
https://youtu.be/5vpQIFzf2Go

# Related References
Fey, M., Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric. ICLR Workshop on Representation Learning on Graphs and Manifolds.

Dhifallah, S., Rekik, I. (2020). Estimation of connectional brain templates using selective multi-view network normalization (netNorm).

# arXiv link

You can download our paper at: https://arxiv.org/abs/2009.11166

# Please cite the following paper when using gGAN:

@article{gurler2020,
  title={Foreseeing Brain Graph Evolution Over Time Using Deep Adversarial Network Normalizer},
  author={Gurler, Zeynep and Nebli, Ahmed and Rekik, Islem},
  journal={Predictive Intelligence in Medicine International Society and Conference Series on Medical Image Computing and Computer-Assisted Intervention},
  volume={},
  pages={},
  year={2020},
  publisher={Springer}
}
118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | import numpy as np 5 | import math 6 | import itertools 7 | import torch 8 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout 9 | from sklearn.preprocessing import MinMaxScaler 10 | from sklearn import preprocessing 11 | from torch_geometric.data import Data 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | import torch.nn as nn 15 | from torch_geometric.nn import NNConv 16 | from torch_geometric.nn import BatchNorm, EdgePooling, TopKPooling, global_add_pool 17 | from sklearn.model_selection import KFold 18 | from sklearn.cluster import KMeans 19 | import matplotlib.pyplot as plt 20 | import scipy.io 21 | import scipy.stats as stats 22 | import pandas as pd 23 | import seaborn as sns 24 | import random 25 | from gGAN import gGAN, netNorm 26 | 27 | torch.cuda.empty_cache() 28 | torch.cuda.empty_cache() 29 | 30 | # random seed 31 | manualSeed = 1 32 | 33 | np.random.seed(manualSeed) 34 | random.seed(manualSeed) 35 | torch.manual_seed(manualSeed) 36 | 37 | if torch.cuda.is_available(): 38 | device = torch.device('cuda') 39 | print('running on GPU') 40 | # if you are using GPU 41 | torch.cuda.manual_seed(manualSeed) 42 | torch.cuda.manual_seed_all(manualSeed) 43 | 44 | torch.backends.cudnn.enabled = False 45 | torch.backends.cudnn.benchmark = False 46 | torch.backends.cudnn.deterministic = True 47 | 48 | else: 49 | device = torch.device("cpu") 50 | print('running on CPU') 51 | 52 | 53 | def demo(): 54 | def cast_data(array_of_tensors, version): 55 | version1 = torch.tensor(version, dtype=torch.int) 56 | 57 | N_ROI = array_of_tensors[0].shape[0] 58 | CHANNELS = 1 59 | dataset = [] 60 | edge_index = torch.zeros(2, N_ROI * N_ROI) 61 | edge_attr = 
torch.zeros(N_ROI * N_ROI, CHANNELS) 62 | x = torch.zeros((N_ROI, N_ROI)) # 35 x 35 63 | y = torch.zeros((1,)) 64 | 65 | counter = 0 66 | for i in range(N_ROI): 67 | for j in range(N_ROI): 68 | edge_index[:, counter] = torch.tensor([i, j]) 69 | counter += 1 70 | for mat in array_of_tensors: # 1,35,35,4 71 | 72 | if version1 == 0: 73 | edge_attr = mat.view(1225, 1) 74 | x = mat.view(nbr_of_regions, nbr_of_regions) 75 | edge_index = torch.tensor(edge_index, dtype=torch.long) 76 | edge_attr = torch.tensor(edge_attr, dtype=torch.float) 77 | x = torch.tensor(x, dtype=torch.float) 78 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 79 | dataset.append(data) 80 | 81 | elif version1 == 1: 82 | edge_attr = torch.randn(N_ROI * N_ROI, CHANNELS) 83 | x = torch.randn(N_ROI, N_ROI) # 35 x 35 84 | edge_index = torch.tensor(edge_index, dtype=torch.long) 85 | edge_attr = torch.tensor(edge_attr, dtype=torch.float) 86 | x = torch.tensor(x, dtype=torch.float) 87 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 88 | dataset.append(data) 89 | 90 | return dataset 91 | 92 | ##################################################################################################### 93 | 94 | def linear_features(data): 95 | n_roi = data[0].shape[0] 96 | n_sub = data.shape[0] 97 | counter = 0 98 | 99 | num_feat = (n_roi * (n_roi - 1) // 2) 100 | final_data = np.empty([n_sub, num_feat], dtype=float) 101 | for k in range(n_sub): 102 | for i in range(n_roi): 103 | for j in range(i+1, n_roi): 104 | final_data[k, counter] = data[k, i, j] 105 | counter += 1 106 | counter = 0 107 | 108 | return final_data 109 | 110 | def make_sym_matrix(nbr_of_regions, feature_vector): 111 | sym_matrix = np.zeros([9, feature_vector.shape[1], nbr_of_regions, nbr_of_regions], dtype=np.double) 112 | for j in range(9): 113 | for i in range(feature_vector.shape[1]): 114 | my_matrix = np.zeros([nbr_of_regions, nbr_of_regions], dtype=np.double) 115 | 116 | my_matrix[np.triu_indices(nbr_of_regions, 
k=1)] = feature_vector[j, i, :] 117 | my_matrix = my_matrix + my_matrix.T 118 | my_matrix[np.diag_indices(nbr_of_regions)] = 0 119 | sym_matrix[j, i,:,:] = my_matrix 120 | 121 | return sym_matrix 122 | 123 | def plot_predictions(predicted, fold): 124 | plt.clf() 125 | for j in range(predicted.shape[0]): 126 | for i in range(predicted.shape[1]): 127 | predicted_sub = predicted[j, i, :, :] 128 | plt.pcolor(abs(predicted_sub)) 129 | if(j == 0 and i == 0): 130 | plt.colorbar() 131 | plt.imshow(predicted_sub) 132 | plt.savefig('./plot/img' + str(fold) + str(j) + str(i) + '.png') 133 | 134 | def plot_MAE(prediction, data_next, test, fold): 135 | # mae 136 | MAE = np.zeros((9), dtype=np.double) 137 | for i in range(9): 138 | MAE_i = abs(prediction[i, :, :] - data_next[test]) 139 | MAE[i] = np.mean(MAE_i) 140 | 141 | plt.clf() 142 | k = ['k=2', 'k=3', 'k=4', 'k=5', 'k=6', 'k=7', 'k=8', 'k=9', 'k=10'] 143 | sns.set(style="whitegrid") 144 | 145 | df = pd.DataFrame(dict(x=k, y=MAE)) 146 | # total = sns.load_dataset('tips') 147 | ax = sns.barplot(x="x", y="y", data=df) 148 | min = MAE.min() - 0.01 149 | max = MAE.max() + 0.01 150 | ax.set(ylim=(min, max)) 151 | plt.savefig('./plot/mae' + str(fold) + '.png') 152 | 153 | ###################################################################################################################################### 154 | 155 | class Generator(nn.Module): 156 | def __init__(self): 157 | super(Generator, self).__init__() 158 | 159 | nn = Sequential(Linear(1, 1225), ReLU()) 160 | self.conv1 = NNConv(35, 35, nn, aggr='mean', root_weight=True, bias=True) 161 | self.conv11 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 162 | 163 | nn = Sequential(Linear(1, 35), ReLU()) 164 | self.conv2 = NNConv(35, 1, nn, aggr='mean', root_weight=True, bias=True) 165 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 166 | 167 | nn = Sequential(Linear(1, 35), ReLU()) 168 | self.conv3 = 
NNConv(1, 35, nn, aggr='mean', root_weight=True, bias=True) 169 | self.conv33 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 170 | 171 | 172 | 173 | def forward(self, data): 174 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 175 | 176 | x1 = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr))) 177 | x1 = F.dropout(x1, training=self.training) 178 | 179 | x2 = F.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr))) 180 | x2 = F.dropout(x2, training=self.training) 181 | 182 | embedded = x2.detach().cpu().clone().numpy() 183 | 184 | return embedded 185 | 186 | def embed(Casted_source): 187 | embedded_data = np.zeros((1, 35), dtype=float) 188 | i = 0 189 | for data_A in Casted_source: ## take a subject from source and target data 190 | embedded = generator(data_A) # 35 x35 191 | 192 | if i == 0: 193 | embedded = np.transpose(embedded) 194 | embedded_data = embedded 195 | else: 196 | embedded = np.transpose(embedded) 197 | embedded_data = np.append(embedded_data, embedded, axis=0) 198 | i = i + 1 199 | return embedded_data 200 | 201 | def test_gGAN(data_next, embedded_train_data, embedded_test_data, embedded_CBT): 202 | def x_to_x(x_train, x_test, nbr_of_trn, nbr_of_tst): 203 | result = np.empty((nbr_of_tst, nbr_of_trn), dtype=float) 204 | for i in range(nbr_of_tst): 205 | x_t = np.transpose(x_test[i]) 206 | for j in range(nbr_of_trn): 207 | result[i, j] = np.matmul(x_train[j], x_t) 208 | return result 209 | 210 | def check(neighbors, i, j): 211 | for val in neighbors[i, :]: 212 | if val == j: 213 | return 1 214 | return 0 215 | 216 | def k_neighbors(x_to_x, k_num, nbr_of_trn, nbr_of_tst): 217 | neighbors = np.zeros((nbr_of_tst, k_num), dtype=int) 218 | used = np.zeros((nbr_of_tst, nbr_of_trn), dtype=int) 219 | current = 0 220 | for i in range(nbr_of_tst): 221 | for k in range(k_num): 222 | for j in range(nbr_of_trn): 223 | if abs(x_to_x[i, j]) > current: 224 | if check(neighbors, i, j) == 0: 225 | 
neighbors[i, k] = j 226 | current = abs(x_to_x[i, neighbors[i, k]]) 227 | current = 0 228 | 229 | return neighbors 230 | 231 | def subtract_cbt(x, cbt, length): 232 | for i in range(length): 233 | x[i] = abs(x[i] - cbt[0]) 234 | 235 | return x 236 | 237 | def predict_samples(k_neighbors, t1, nbr_of_tst): 238 | average = np.zeros((nbr_of_tst, 595), dtype=float) 239 | for i in range(nbr_of_tst): 240 | for j in range(len(k_neighbors[0])): 241 | average[i] = average[i] + t1[k_neighbors[i,j],:] 242 | 243 | average[i] = average[i] / len(k_neighbors[0]) 244 | 245 | return average 246 | 247 | residual_of_tr_embeddings = subtract_cbt(embedded_train_data, embedded_CBT, len(embedded_train_data)) 248 | residual_of_ts_embeddings = subtract_cbt(embedded_test_data, embedded_CBT, len(embedded_test_data)) 249 | 250 | dot_of_residuals = x_to_x(residual_of_tr_embeddings, residual_of_ts_embeddings, len(train), len(test)) 251 | for k in range(2, 11): 252 | k_neighbors_ = k_neighbors(dot_of_residuals, k, len(train), len(test)) 253 | 254 | if k == 2: 255 | prediction = predict_samples(k_neighbors_, data_next, len(embedded_test_data)) 256 | prediction = np.reshape(prediction, (1, len(embedded_test_data), nbr_of_feat)) 257 | else: 258 | new_predict = predict_samples(k_neighbors_, data_next, len(embedded_test_data)) 259 | new_predict = np.reshape(new_predict, (1, len(embedded_test_data), nbr_of_feat)) 260 | prediction = np.append(prediction, new_predict, axis=0) 261 | 262 | return prediction 263 | 264 | nbr_of_sub = int(input('Please select the number of subjects: ')) 265 | if nbr_of_sub < 5: 266 | print("You can not give less than 5 to the number of subjects. 
") 267 | nbr_of_sub = int(input('Please select the number of subjects: ')) 268 | nbr_of_sub_for_cbt = int(input('Please select the number of subjects to generate the CBT: ')) 269 | nbr_of_regions = int(input('Please select the number of regions: ')) 270 | nbr_of_epochs = int(input('Please select the number of epochs: ')) 271 | nbr_of_folds = int(input('Please select the number of folds: ')) 272 | hyper_param1 = 100 273 | nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2) 274 | 275 | data = np.random.normal(0.6, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions)) 276 | data = np.abs(data) 277 | independent_data = np.random.normal(0.6, 0.3, (nbr_of_sub_for_cbt, nbr_of_regions, nbr_of_regions)) 278 | independent_data = np.abs(independent_data) 279 | data_next = np.random.normal(0.4, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions)) 280 | data_next = np.abs(data_next) 281 | CBT = netNorm(independent_data, nbr_of_sub_for_cbt, nbr_of_regions) 282 | gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT) 283 | 284 | # embed train and test subjects 285 | kfold = KFold(n_splits=nbr_of_folds, shuffle=True, random_state=manualSeed) 286 | 287 | source_data = torch.from_numpy(data) # convert numpy array to torch tensor 288 | source_data = source_data.type(torch.FloatTensor) 289 | 290 | target_data = np.reshape(CBT, (1, nbr_of_regions, nbr_of_regions, 1)) 291 | target_data = torch.from_numpy(target_data) # convert numpy array to torch tensor 292 | target_data = target_data.type(torch.FloatTensor) 293 | 294 | i = 1 295 | for train, test in kfold.split(source_data): 296 | adversarial_loss = torch.nn.BCELoss() 297 | l1_loss = torch.nn.L1Loss() 298 | trained_model_gen = torch.load('./weight_' + str(i) + 'generator_.model') 299 | generator = Generator() 300 | generator.load_state_dict(trained_model_gen) 301 | 302 | train_data = source_data[train] 303 | test_data = source_data[test] 304 | 305 | generator.to(device) 306 | adversarial_loss.to(device) 
307 | l1_loss.to(device) 308 | 309 | X_train_casted_source = [d.to(device) for d in cast_data(train_data, 0)] 310 | X_test_casted_source = [d.to(device) for d in cast_data(test_data, 0)] 311 | data_B = [d.to(device) for d in cast_data(target_data, 0)] 312 | 313 | embedded_train_data = embed(X_train_casted_source) 314 | embedded_test_data = embed(X_test_casted_source) 315 | embedded_CBT = embed(data_B) 316 | 317 | if i == 1: 318 | data_next = linear_features(data_next) 319 | predicted_flat = test_gGAN(data_next, embedded_train_data, embedded_test_data, embedded_CBT) 320 | 321 | plot_MAE(predicted_flat, data_next, test, i) 322 | i = i + 1 323 | 324 | predicted = make_sym_matrix(nbr_of_regions, predicted_flat) 325 | plot_predictions(predicted, i - 1) 326 | 327 | demo() 328 | 329 | -------------------------------------------------------------------------------- /gGAN.py: -------------------------------------------------------------------------------- 1 | """Main function of gGAN for the paper: Foreseeing Brain Graph Evolution Over Time 2 | Using Deep Adversarial Network Normalizer 3 | Details can be found in: (https://arxiv.org/abs/2009.11166) 4 | (1) the original paper . 
5 | --------------------------------------------------------------------- 6 | This file contains the implementation of two key steps of our gGAN framework: 7 | netNorm(v, nbr_of_sub, nbr_of_regions) 8 | Inputs: 9 | v: (n × t x t) matrix stacking the source graphs of all subjects 10 | n the total number of subjects 11 | t number of regions 12 | Output: 13 | CBT: (t x t) matrix representing the connectional brain template 14 | 15 | gGAN(sourceGraph, nbr_of_regions, nbr_of_folds, nbr_of_epochs, hyper_param1, CBT) 16 | Inputs: 17 | sourceGraph: (n × t x t) matrix stacking the source graphs of all subjects 18 | n the total number of subjects 19 | t number of regions 20 | CBT: (t x t) matrix stacking the connectional brain template generated by netNorm 21 | 22 | Output: 23 | translatedGraph: (t x t) matrix stacking the graph translated into CBT 24 | 25 | This code has been slightly modified to be compatible across all PyTorch versions. 26 | 27 | (2) Dependencies: please install the following libraries: 28 | - matplotlib 29 | - numpy 30 | - scikitlearn 31 | - pytorch 32 | - pytorch-geometric 33 | - pytorch-scatter 34 | - pytorch-sparse 35 | - scipy 36 | 37 | --------------------------------------------------------------------- 38 | Copyright 2020 (). 39 | Please cite the above paper if you use this code. 40 | All rights reserved. 41 | """ 42 | 43 | 44 | # If you are using Google Colab please uncomment the three following lines. 
45 | # !pip install torch_geometric 46 | # !pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html 47 | # !pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html 48 | 49 | 50 | import argparse 51 | import os 52 | import pdb 53 | import numpy as np 54 | import math 55 | import itertools 56 | import torch 57 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout 58 | from sklearn.preprocessing import MinMaxScaler 59 | from sklearn import preprocessing 60 | from torch_geometric.data import Data 61 | from torch.autograd import Variable 62 | import torch.nn.functional as F 63 | import torch.nn as nn 64 | from torch_geometric.nn import NNConv 65 | from torch_geometric.nn import BatchNorm, EdgePooling, TopKPooling, global_add_pool 66 | from sklearn.model_selection import KFold 67 | from sklearn.cluster import KMeans 68 | import matplotlib.pyplot as plt 69 | import scipy.io 70 | import scipy.stats as stats 71 | import random 72 | 73 | import seaborn as sns 74 | 75 | torch.cuda.empty_cache() 76 | torch.cuda.empty_cache() 77 | 78 | # random seed 79 | manualSeed = 1 80 | 81 | np.random.seed(manualSeed) 82 | random.seed(manualSeed) 83 | torch.manual_seed(manualSeed) 84 | 85 | if torch.cuda.is_available(): 86 | device = torch.device('cuda') 87 | print('running on GPU') 88 | # if you are using GPU 89 | torch.cuda.manual_seed(manualSeed) 90 | torch.cuda.manual_seed_all(manualSeed) 91 | 92 | torch.backends.cudnn.enabled = False 93 | torch.backends.cudnn.benchmark = False 94 | torch.backends.cudnn.deterministic = True 95 | 96 | else: 97 | device = torch.device("cpu") 98 | print('running on CPU') 99 | 100 | def netNorm(v, nbr_of_sub, nbr_of_regions): 101 | nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2) 102 | 103 | def upper_triangular(): 104 | All_subj = np.zeros((nbr_of_sub, nbr_of_feat)) 105 | for j in range(nbr_of_sub): 106 | subj_x = v[j, :, :] 107 | subj_x = 
np.reshape(subj_x, (nbr_of_regions, nbr_of_regions)) 108 | subj_x = subj_x[np.triu_indices(nbr_of_regions, k=1)] 109 | subj_x = np.reshape(subj_x, (1, nbr_of_feat)) 110 | All_subj[j, :] = subj_x 111 | 112 | return All_subj 113 | 114 | def distances_inter(All_subj): 115 | theta = 0 116 | distance_vector = np.zeros(1) 117 | distance_vector_final = np.zeros(1) 118 | x = All_subj 119 | for i in range(nbr_of_feat): 120 | ROI_i = x[:, i] 121 | for j in range(nbr_of_sub): 122 | subj_j = ROI_i[j:j+1] 123 | 124 | distance_euclidienne_sub_j_sub_k = 0 125 | for k in range(nbr_of_sub): 126 | if k != j: 127 | subj_k = ROI_i[k:k+1] 128 | 129 | distance_euclidienne_sub_j_sub_k = distance_euclidienne_sub_j_sub_k + np.square(subj_k - subj_j) 130 | theta +=1 131 | if j == 0: 132 | distance_vector = np.sqrt(distance_euclidienne_sub_j_sub_k) 133 | else: 134 | distance_vector = np.concatenate((distance_vector, np.sqrt(distance_euclidienne_sub_j_sub_k)), axis=0) 135 | 136 | distance_vector = np.reshape(distance_vector, (nbr_of_sub, 1)) 137 | if i == 0: 138 | distance_vector_final = distance_vector 139 | else: 140 | distance_vector_final = np.concatenate((distance_vector_final, distance_vector), axis=1) 141 | 142 | print(theta) 143 | return distance_vector_final 144 | 145 | 146 | def minimum_distances(distance_vector_final): 147 | x = distance_vector_final 148 | 149 | for i in range(nbr_of_feat): 150 | minimum_sub = x[0, i:i+1] 151 | minimum_sub = float(minimum_sub) 152 | general_minimum = 0 153 | general_minimum = np.array(general_minimum) 154 | for k in range(1, nbr_of_sub): 155 | local_sub = x[k:k+1, i:i+1] 156 | local_sub = float(local_sub) 157 | if local_sub < minimum_sub: 158 | general_minimum = k 159 | general_minimum = np.array(general_minimum) 160 | minimum_sub = local_sub 161 | if i == 0: 162 | final_general_minimum = np.array(general_minimum) 163 | else: 164 | final_general_minimum = np.vstack((final_general_minimum, general_minimum)) 165 | 166 | final_general_minimum = 
np.transpose(final_general_minimum) 167 | 168 | return final_general_minimum 169 | 170 | def new_tensor(final_general_minimum, All_subj): 171 | y = All_subj 172 | x = final_general_minimum 173 | for i in range(nbr_of_feat): 174 | optimal_subj = x[:, i:i+1] 175 | optimal_subj = np.reshape(optimal_subj, (1)) 176 | optimal_subj = int(optimal_subj) 177 | if i == 0: 178 | final_new_tensor = y[optimal_subj: optimal_subj+1, i:i+1] 179 | else: 180 | final_new_tensor = np.concatenate((final_new_tensor, y[optimal_subj: optimal_subj+1, i:i+1]), axis=1) 181 | 182 | return final_new_tensor 183 | 184 | def make_sym_matrix(nbr_of_regions, feature_vector): 185 | my_matrix = np.zeros([nbr_of_regions, nbr_of_regions], dtype=np.double) 186 | 187 | my_matrix[np.triu_indices(nbr_of_regions, k=1)] = feature_vector 188 | my_matrix = my_matrix + my_matrix.T 189 | my_matrix[np.diag_indices(nbr_of_regions)] = 0 190 | 191 | return my_matrix 192 | 193 | def re_make_tensor(final_new_tensor, nbr_of_regions): 194 | x = final_new_tensor 195 | #x = np.reshape(x, (nbr_of_views, nbr_of_feat)) 196 | 197 | x = make_sym_matrix(nbr_of_regions, x) 198 | x = np.reshape(x, (1, nbr_of_regions, nbr_of_regions)) 199 | 200 | return x 201 | 202 | Upp_trig = upper_triangular() 203 | Dis_int = distances_inter(Upp_trig) 204 | Min_dis = minimum_distances(Dis_int) 205 | New_ten = new_tensor(Min_dis, Upp_trig) 206 | Re_ten = re_make_tensor(New_ten, nbr_of_regions) 207 | Re_ten = np.reshape(Re_ten, (nbr_of_regions, nbr_of_regions)) 208 | np.fill_diagonal(Re_ten, 0) 209 | network = np.array(Re_ten) 210 | return network 211 | 212 | def gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT): 213 | def cast_data(array_of_tensors, version): 214 | version1 = torch.tensor(version, dtype=torch.int) 215 | 216 | N_ROI = array_of_tensors[0].shape[0] 217 | CHANNELS = 1 218 | dataset = [] 219 | edge_index = torch.zeros(2, N_ROI * N_ROI) 220 | edge_attr = torch.zeros(N_ROI * N_ROI, CHANNELS) 221 | x = 
torch.zeros((N_ROI, N_ROI)) # 35 x 35 222 | y = torch.zeros((1,)) 223 | 224 | counter = 0 225 | for i in range(N_ROI): 226 | for j in range(N_ROI): 227 | edge_index[:, counter] = torch.tensor([i, j]) 228 | counter += 1 229 | for mat in array_of_tensors: #1,35,35,4 230 | 231 | if version1 == 0: 232 | edge_attr = mat.view((nbr_of_regions*nbr_of_regions), 1) 233 | x = mat.view(nbr_of_regions, nbr_of_regions) 234 | edge_index = torch.tensor(edge_index, dtype=torch.long) 235 | edge_attr = torch.tensor(edge_attr, dtype=torch.float) 236 | x = torch.tensor(x, dtype=torch.float) 237 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 238 | dataset.append(data) 239 | 240 | elif version1 == 1: 241 | edge_attr = torch.randn(N_ROI * N_ROI, CHANNELS) 242 | x = torch.randn(N_ROI, N_ROI) # 35 x 35 243 | edge_index = torch.tensor(edge_index, dtype=torch.long) 244 | edge_attr = torch.tensor(edge_attr, dtype=torch.float) 245 | x = torch.tensor(x, dtype=torch.float) 246 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 247 | dataset.append(data) 248 | 249 | return dataset 250 | 251 | # ------------------------------------------------------------ 252 | 253 | def plotting_loss(losses_generator, losses_discriminator, epoch): 254 | plt.figure(1) 255 | plt.plot(epoch, losses_generator, 'r-') 256 | plt.plot(epoch, losses_discriminator, 'b-') 257 | plt.legend(['G Loss', 'D Loss']) 258 | plt.xlabel('Epoch') 259 | plt.ylabel('Loss') 260 | plt.savefig('./plot/loss' + str(epoch) + '.png') 261 | 262 | # ------------------------------------------------------------- 263 | 264 | class Generator(nn.Module): 265 | def __init__(self): 266 | super(Generator, self).__init__() 267 | 268 | nn = Sequential(Linear(1, (nbr_of_regions*nbr_of_regions)), ReLU()) 269 | self.conv1 = NNConv(nbr_of_regions, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True) 270 | self.conv11 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 271 | 272 | nn = 
Sequential(Linear(1, nbr_of_regions), ReLU()) 273 | self.conv2 = NNConv(nbr_of_regions, 1, nn, aggr='mean', root_weight=True, bias=True) 274 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 275 | 276 | nn = Sequential(Linear(1, nbr_of_regions), ReLU()) 277 | self.conv3 = NNConv(1, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True) 278 | self.conv33 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 279 | 280 | def forward(self, data): 281 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 282 | 283 | x1 = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr))) 284 | x1 = F.dropout(x1, training=self.training) 285 | 286 | x2 = F.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr))) 287 | x2 = F.dropout(x2, training=self.training) 288 | 289 | x3 = torch.cat([F.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1) 290 | x4 = x3[:, 0:nbr_of_regions] 291 | x5 = x3[:, nbr_of_regions:2*nbr_of_regions] 292 | 293 | x6 = (x4 + x5) / 2 294 | return x6 295 | 296 | class Discriminator1(torch.nn.Module): 297 | def __init__(self): 298 | super(Discriminator1, self).__init__() 299 | nn = Sequential(Linear(2, (nbr_of_regions*nbr_of_regions)), ReLU()) 300 | self.conv1 = NNConv(nbr_of_regions, nbr_of_regions, nn, aggr='mean', root_weight=True, bias=True) 301 | self.conv11 = BatchNorm(nbr_of_regions, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 302 | 303 | nn = Sequential(Linear(2, nbr_of_regions), ReLU()) 304 | self.conv2 = NNConv(nbr_of_regions, 1, nn, aggr='mean', root_weight=True, bias=True) 305 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 306 | 307 | 308 | def forward(self, data, data_to_translate): 309 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 310 | edge_attr_data_to_translate = data_to_translate.edge_attr 311 | 312 | 
            edge_attr_data_to_translate_reshaped = edge_attr_data_to_translate.view(nbr_of_regions * nbr_of_regions, 1)

            # Two-channel edge features: edge weights of `data` next to those of `data_to_translate`
            gen_input = torch.cat((edge_attr, edge_attr_data_to_translate_reshaped), -1)
            x = F.relu(self.conv11(self.conv1(x, edge_index, gen_input)))
            x = F.dropout(x, training=self.training)
            x = F.relu(self.conv22(self.conv2(x, edge_index, gen_input)))

            # Per-node score; sigmoid is applied after ReLU, so outputs lie in [0.5, 1)
            return torch.sigmoid(x)

    # ----------------------------------------
    # Training
    # ----------------------------------------

    n_fold_counter = 1
    plot_loss_g = np.empty((nbr_of_epochs), dtype=float)
    plot_loss_d = np.empty((nbr_of_epochs), dtype=float)

    kfold = KFold(n_splits=nbr_of_folds, shuffle=True, random_state=manualSeed)

    source_data = torch.from_numpy(data)  # convert numpy array to torch tensor
    source_data = source_data.type(torch.FloatTensor)

    target_data = np.reshape(CBT, (1, nbr_of_regions, nbr_of_regions, 1))
    target_data = torch.from_numpy(target_data)  # convert numpy array to torch tensor
    target_data = target_data.type(torch.FloatTensor)

    for train, test in kfold.split(source_data):
        # Loss functions
        adversarial_loss = torch.nn.BCELoss()
        l1_loss = torch.nn.L1Loss()
        # Initialize generator and discriminator
        generator = Generator()
        discriminator1 = Discriminator1()

        generator.to(device)
        discriminator1.to(device)
        adversarial_loss.to(device)
        l1_loss.to(device)

        # Optimizers
        optimizer_G = torch.optim.AdamW(generator.parameters(), lr=0.005, betas=(0.5, 0.999))
        optimizer_D = torch.optim.AdamW(discriminator1.parameters(), lr=0.01, betas=(0.5, 0.999))

        # ------------------------------- select source data and target data -------------------------------

        train_source, test_source = source_data[train], source_data[test]  # from a specific source view

        # cast_data's version1 flag: 1 = random node and edge features; 0 = features taken from the matrix itself

        train_casted_source = [d.to(device) for d in cast_data(train_source, 0)]
        train_casted_target = [d.to(device) for d in cast_data(target_data, 0)]

        for epoch in range(nbr_of_epochs):
            # Train generator and discriminator
            with torch.autograd.set_detect_anomaly(True):

                losses_generator = []
                losses_discriminator = []

                for data_A in train_casted_source:
                    generators_output_ = generator(data_A)  # 35 x 35
                    generators_output = generators_output_.view(1, nbr_of_regions, nbr_of_regions, 1).type(torch.FloatTensor)

                    generators_output_casted = [d.to(device) for d in cast_data(generators_output, 0)]
                    for data_discriminator in generators_output_casted:
                        discriminator_output_of_gen = discriminator1(data_discriminator, data_A).to(device)

                        # Generator loss: fool the discriminator and stay close (L1) to the CBT
                        g_loss_adversarial = adversarial_loss(discriminator_output_of_gen, torch.ones_like(discriminator_output_of_gen))

                        g_loss_pix2pix = l1_loss(generators_output_, train_casted_target[0].edge_attr.view(nbr_of_regions, nbr_of_regions))

                        g_loss = g_loss_adversarial + (hyper_param1 * g_loss_pix2pix)
                        losses_generator.append(g_loss)

                        discriminator_output_for_real_loss = discriminator1(data_A, train_casted_target[0])

                        real_loss = adversarial_loss(discriminator_output_for_real_loss,
                                                     torch.ones_like(discriminator_output_for_real_loss))
                        # Re-score a detached copy of the generator output, so the fake loss updates the
                        # discriminator without back-propagating into the generator (detaching the
                        # discriminator's own output instead would give the discriminator zero gradient)
                        fake_graph = cast_data(generators_output.detach(), 0)[0].to(device)
                        discriminator_output_of_fake = discriminator1(fake_graph, data_A)
                        fake_loss = adversarial_loss(discriminator_output_of_fake, torch.zeros_like(discriminator_output_of_fake))

                        d_loss = (real_loss + fake_loss) / 2
                        losses_discriminator.append(d_loss)

                # One optimizer step per epoch on the losses averaged over all training subjects
                optimizer_G.zero_grad()
                losses_generator = torch.mean(torch.stack(losses_generator))
                losses_generator.backward(retain_graph=True)
                optimizer_G.step()

                optimizer_D.zero_grad()
                losses_discriminator = torch.mean(torch.stack(losses_discriminator))

                losses_discriminator.backward(retain_graph=True)
                optimizer_D.step()

            print(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, nbr_of_epochs, losses_discriminator.item(), losses_generator.item()))

            plot_loss_g[epoch] = losses_generator.item()
            plot_loss_d[epoch] = losses_discriminator.item()

        torch.save(generator.state_dict(), "./weight_" + str(n_fold_counter) + "generator" + "_" + ".model")
        torch.save(discriminator1.state_dict(), "./weight_" + str(n_fold_counter) + "dicriminator" + "_" + ".model")

        interval = range(0, nbr_of_epochs)
        plotting_loss(plot_loss_g, plot_loss_d, interval)
        n_fold_counter += 1
        torch.cuda.empty_cache()
    torch.cuda.empty_cache()


nbr_of_sub = int(input('Please select the number of subjects: '))
while nbr_of_sub < 5:
    print("The number of subjects cannot be less than 5.")
    nbr_of_sub = int(input('Please select the number of subjects: '))
nbr_of_sub_for_cbt = int(input('Please select the number of subjects to generate the CBT: '))
nbr_of_regions = int(input('Please select the number of regions: '))
nbr_of_epochs = int(input('Please select the number of epochs: '))
nbr_of_folds = int(input('Please select the number of folds: '))
hyper_param1 = 100  # weight of the L1 term in the generator loss
nbr_of_feat = int((np.square(nbr_of_regions) - nbr_of_regions) / 2)

# Simulated connectivity data: absolute values of Gaussian samples (mean 0.6, std 0.3)
data = np.random.normal(0.6, 0.3, (nbr_of_sub, nbr_of_regions, nbr_of_regions))
data = np.abs(data)
# A separate simulated population, used only to build the CBT with netNorm
independent_data = np.random.normal(0.6, 0.3, (nbr_of_sub_for_cbt, nbr_of_regions, nbr_of_regions))
independent_data = np.abs(independent_data)
CBT = netNorm(independent_data, nbr_of_sub_for_cbt, nbr_of_regions)
gGAN(data, nbr_of_regions, nbr_of_epochs, nbr_of_folds, hyper_param1, CBT)
--------------------------------------------------------------------------------
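The `cast_data` function above turns each N x N connectivity matrix into a fully connected graph: its double loop enumerates all N * N ordered ROI pairs (self-loops included) as `edge_index`, and `mat.view(N*N, 1)` flattens the matrix in row-major order, so edge number k = i * N + j carries the weight of connection (i, j). The sketch below reproduces that indexing in plain Python, without torch or torch_geometric, so the layout can be checked in isolation. It is an illustration under these assumptions, not part of the gGAN code: the helper names (`fully_connected_edges`, `edge_index_rows`, `flatten_row_major`, `edge_weights`) are hypothetical, and the 3-ROI matrix is made up.

```python
from typing import Dict, List, Tuple


def fully_connected_edges(n_roi: int) -> List[Tuple[int, int]]:
    """Every ordered (source, target) ROI pair, self-loops included, in the
    order produced by cast_data's double loop (j varies fastest)."""
    return [(i, j) for i in range(n_roi) for j in range(n_roi)]


def edge_index_rows(edges: List[Tuple[int, int]]) -> List[List[int]]:
    """Split the pairs into the 2 x E layout of a PyG edge_index:
    row 0 holds source nodes, row 1 holds target nodes."""
    return [[i for i, _ in edges], [j for _, j in edges]]


def flatten_row_major(mat: List[List[float]]) -> List[float]:
    """Row-major flattening, mirroring mat.view(N * N, 1) on a contiguous
    N x N tensor (without the trailing channel dimension)."""
    return [value for row in mat for value in row]


def edge_weights(mat: List[List[float]]) -> Dict[Tuple[int, int], float]:
    """Pair each edge (i, j) with its flattened attribute, checking that
    edge number k = i * N + j carries mat[i][j]."""
    n_roi = len(mat)
    attrs = flatten_row_major(mat)
    weights = {}
    for k, (i, j) in enumerate(fully_connected_edges(n_roi)):
        assert k == i * n_roi + j  # sanity check of the row-major index
        weights[(i, j)] = attrs[k]
    return weights


if __name__ == "__main__":
    # Made-up 3-ROI connectivity matrix, for illustration only
    toy = [[0.0, 0.2, 0.5],
           [0.2, 0.0, 0.7],
           [0.5, 0.7, 0.0]]
    edges = fully_connected_edges(3)
    print(len(edges))                 # 9
    print(edge_index_rows(edges))     # [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]]
    print(edge_weights(toy)[(1, 2)])  # 0.7
```

With N = 35 ROIs (the size noted in the `# 35 x 35` comments), this layout gives 35 * 35 = 1225 edges per graph, which is why `edge_attr` has N * N rows and a single channel.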