├── losses └── Placeholder ├── plots └── Placeholder ├── weights └── Placeholder ├── fig1.png ├── Visualization.png ├── data ├── create_data.py └── .ipynb_checkpoints │ └── Untitled-checkpoint.ipynb ├── code ├── plot.py ├── model.py ├── data_utils.py └── EvoGraphNet.py └── README.md /losses/Placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plots/Placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weights/Placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/EvoGraphNet/HEAD/fig1.png -------------------------------------------------------------------------------- /Visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/EvoGraphNet/HEAD/Visualization.png -------------------------------------------------------------------------------- /data/create_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | mean, std = np.random.rand(), np.random.rand() 4 | 5 | for i in range(1, 114): 6 | 7 | # Create adjacency matrices 8 | 9 | t0 = np.abs(np.random.normal(mean, std, (35,35))) % 1.0 10 | mean_s = mean + np.random.rand() % 0.1 11 | std_s = std + np.random.rand() % 0.1 12 | t1 = np.abs(np.random.normal(mean_s, std_s, (35,35))) % 1.0 13 | mean_s = mean + np.random.rand() % 0.1 14 | std_s = std + np.random.rand() % 0.1 15 | t2 = np.abs(np.random.normal(mean_s, std_s, (35,35))) % 1.0 16 | 17 | # Make them symmetric 18 
| 19 | t0 = (t0 + t0.T)/2 20 | t1 = (t1 + t1.T)/2 21 | t2 = (t2 + t2.T)/2 22 | 23 | # Clean the diagonals 24 | t0[np.diag_indices_from(t0)] = 0 25 | t1[np.diag_indices_from(t1)] = 0 26 | t2[np.diag_indices_from(t2)] = 0 27 | 28 | # Save them 29 | s = "cortical.lh.ShapeConnectivityTensor_OAS2_" 30 | if i < 10: 31 | s += "0" 32 | s += "00" + str(i) + "_MR1" 33 | 34 | t0_s = s + "_t0.txt" 35 | t1_s = s + "_t1.txt" 36 | t2_s = s + "_t2.txt" 37 | 38 | np.savetxt(t0_s, t0) 39 | np.savetxt(t1_s, t1) 40 | np.savetxt(t2_s, t2) 41 | -------------------------------------------------------------------------------- /code/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import math 6 | import itertools 7 | import copy 8 | import pickle 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU 14 | from torch.autograd import Variable 15 | 16 | from sklearn import preprocessing 17 | from sklearn.preprocessing import MinMaxScaler 18 | from sklearn.model_selection import KFold 19 | 20 | from torch_geometric.data import Data, InMemoryDataset, DataLoader 21 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool 22 | 23 | import matplotlib.pyplot as plt 24 | 25 | 26 | def plot(loss, title, losses): 27 | fig = plt.figure() 28 | plt.plot(losses) 29 | plt.xlabel("# epoch") 30 | plt.ylabel(loss) 31 | plt.title(title) 32 | plt.savefig('../plots/' + title + '.png') 33 | plt.close() 34 | 35 | 36 | def plot_matrix(out, fold, sample, epoch, strategy): 37 | fig = plt.figure() 38 | plt.pcolor(abs(out)) 39 | plt.colorbar() 40 | plt.imshow(out) 41 | title = "Generator Output, Epoch = " + str(epoch) + " Fold = " + str(fold) + " Strategy = " + strategy 42 | plt.title(title) 43 | plt.savefig('../plots/' + str(fold) + 'Gen_' + 
str(sample) + '_' + str(epoch) + '.png') 44 | plt.close() 45 | 46 | 47 | -------------------------------------------------------------------------------- /data/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 46, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 47, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "t0 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t0.txt\")\n", 20 | "t1 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t1.txt\")\n", 21 | "t2 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t2.txt\")" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 48, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def plot_matrix(m):\n", 31 | " plt.matshow(m)\n", 32 | " plt.colorbar()\n", 33 | " plt.show()" 34 | ] 35 | } 36 | ], 37 | "metadata": { 38 | "kernelspec": { 39 | "display_name": "Python 3", 40 | "language": "python", 41 | "name": "python3" 42 | }, 43 | "language_info": { 44 | "codemirror_mode": { 45 | "name": "ipython", 46 | "version": 3 47 | }, 48 | "file_extension": ".py", 49 | "mimetype": "text/x-python", 50 | "name": "python", 51 | "nbconvert_exporter": "python", 52 | "pygments_lexer": "ipython3", 53 | "version": "3.7.7" 54 | } 55 | }, 56 | "nbformat": 4, 57 | "nbformat_minor": 4 58 | } 59 | -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import math 6 | import itertools 7 | import copy 8 | import pickle 9 | 10 | import torch 11 | import 
torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU 14 | from torch.autograd import Variable 15 | 16 | from sklearn import preprocessing 17 | from sklearn.preprocessing import MinMaxScaler 18 | from sklearn.model_selection import KFold 19 | 20 | from torch_geometric.data import Data, InMemoryDataset, DataLoader 21 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool 22 | 23 | import matplotlib.pyplot as plt 24 | 25 | 26 | class Generator(nn.Module): 27 | def __init__(self): 28 | super(Generator, self).__init__() 29 | 30 | lin = Sequential(Linear(1, 1225), ReLU()) 31 | self.conv1 = NNConv(35, 35, lin, aggr='mean', root_weight=True, bias=True) 32 | self.conv11 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 33 | 34 | lin = Sequential(Linear(1, 35), ReLU()) 35 | self.conv2 = NNConv(35, 1, lin, aggr='mean', root_weight=True, bias=True) 36 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 37 | 38 | lin = Sequential(Linear(1, 35), ReLU()) 39 | self.conv3 = NNConv(1, 35, lin, aggr='mean', root_weight=True, bias=True) 40 | self.conv33 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 41 | 42 | def forward(self, data): 43 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 44 | 45 | x1 = torch.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr))) 46 | x1 = F.dropout(x1, training=self.training) 47 | #Below 2 lines are the corrections 48 | x1 = (x1 + x1.T) / 2.0 49 | x1.fill_diagonal_(fill_value = 0) 50 | x2 = torch.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr))) 51 | x2 = F.dropout(x2, training=self.training) 52 | 53 | x3 = torch.cat([torch.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1) 54 | x4 = x3[:, 0:35] 55 | x5 = x3[:, 35:70] 56 | 57 | x6 = (x4 + x5) / 2 58 | #Below 2 lines are the 
corrections 59 | x6 = (x6 + x6.T) / 2.0 60 | x6.fill_diagonal_(fill_value = 0) 61 | return x6 62 | 63 | 64 | class Discriminator(torch.nn.Module): 65 | 66 | def __init__(self): 67 | super(Discriminator, self).__init__() 68 | lin = Sequential(Linear(2, 1225), ReLU()) 69 | self.conv1 = NNConv(35, 35, lin, aggr='mean', root_weight=True, bias=True) 70 | self.conv11 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 71 | 72 | lin = Sequential(Linear(2, 35), ReLU()) 73 | self.conv2 = NNConv(35, 1, lin, aggr='mean', root_weight=True, bias=True) 74 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True) 75 | 76 | def forward(self, data, data_to_translate): 77 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 78 | edge_attr_data_to_translate = data_to_translate.edge_attr 79 | 80 | edge_attr_data_to_translate_reshaped = edge_attr_data_to_translate.view(1225, 1) 81 | 82 | gen_input = torch.cat((edge_attr, edge_attr_data_to_translate_reshaped), -1) 83 | x = F.relu(self.conv11(self.conv1(x, edge_index, gen_input))) 84 | x = F.dropout(x, training=self.training) 85 | x = F.relu(self.conv22(self.conv2(x, edge_index, gen_input))) 86 | 87 | return torch.sigmoid(x) 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EvoGraphNet 2 | EvoGraphNet for joint prediction of brain connection evolution, coded up in Python by Uğur Ali Kaplan (uguralikaplan@gmail.com) and Ahmed Nebli (mr.ahmednebli@gmail.com). 3 | 4 | This repository provides the official PyTorch implementation of the following paper: 5 | 6 | ![fig1](fig1.png) 7 | 8 | > **Deep EvoGraphNet Architecture For Time-Dependent Brain Graph Data Synthesis From a Single Timepoint**
9 | > [Ahmed Nebli](https://github.com/ahmednebli)<sup>†1,2</sup>, [Uğur Ali Kaplan](https://github.com/UgurKap)<sup>†1</sup>, [Islem Rekik](https://basira-lab.com/)<sup>1</sup>
10 | > <sup>1</sup>BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey
11 | > <sup>2</sup>National School for Computer Science (ENSI), Manouba, Tunisia
12 | > <sup>†</sup>Equal Contribution
13 | > 14 | > **Abstract:** *Learning how to predict the brain connectome (i.e. graph) development and aging is of paramount importance for charting the future of within-disorder and cross-disorder landscape of brain dysconnectivity evolution. Indeed, predicting the longitudinal (i.e., time-dependent) brain dysconnectivity as it emerges and evolves over time from a single timepoint can help design personalized treatments for disordered patients in a very early stage. Despite its significance, evolution models of the brain graph are largely overlooked in the literature. Here, we propose EvoGraphNet, the first end-to-end geometric deep learning powered graph-generative adversarial network (gGAN) for predicting time-dependent brain graph evolution from a single timepoint. Our EvoGraphNet architecture cascades a set of time-dependent gGANs, where each gGAN communicates its predicted brain graphs at a particular timepoint to train the next gGAN in the cascade at follow-up timepoint. Therefore, we obtain each next predicted timepoint by setting the output of each generator as the input of its successor which enables us to predict a given number of timepoints using only one single timepoint in an end-to-end fashion. At each timepoint, to better align the distribution of the predicted brain graphs with that of the ground-truth graphs, we further integrate an auxiliary Kullback-Leibler divergence loss function. To capture time-dependency between two consecutive observations, we impose an l1 loss to minimize the sparse distance between two serialized brain graphs. 
A series of benchmarks against variants and ablated versions of our EvoGraphNet showed that we can achieve the lowest brain graph evolution prediction error using a single baseline timepoint.* 15 | 16 | ## Dependencies 17 | * [Python 3.8+](https://www.python.org/) 18 | * [PyTorch 1.5.0+](http://pytorch.org/) 19 | * [PyTorch Geometric 1.4.3+ and Relevant Packages](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 20 | * [Scikit-learn 0.23.0+](https://scikit-learn.org/stable/) 21 | * [Matplotlib 3.1.3+](https://matplotlib.org/) 22 | * [Numpy 1.18.1+](https://numpy.org/) 23 | 24 | ## Simulating Time-series data 25 | 26 | To simulate longitudinal brain data, run the create_data.py script from inside the "data" directory (it writes the generated matrices to the current working directory). It will create 113 random samples. 27 | 28 | ```bash 29 | python create_data.py 30 | ``` 31 | 32 | ## Running EvoGraphNet 33 | 34 | You can use the EvoGraphNet.py script located in the "code" directory to run the model; run it from inside that directory, since it reads the data from "../data". To set the parameters, provide command-line arguments. 35 | 36 | You can run the program with the following command: 37 | 38 | ```bash 39 | python EvoGraphNet.py --loss LS --epoch 500 --folds 5 40 | ``` 41 | 42 | In this example, we use Least Squares as the adversarial loss and train for 500 epochs in each of the 5 folds. To use the hyperparameters described in the paper, run it without any command-line arguments: 43 | 44 | ```bash 45 | python EvoGraphNet.py 46 | ``` 47 | 48 | Other Command-line Arguments: 49 | 50 | --lr_g: Generator learning rate 51 | --lr_d: Discriminator learning rate 52 | --loss: Which adversarial loss to use for training, choices: BCE, LS 53 | --batch: Batch Size 54 | --epoch: How many epochs to train 55 | --folds: How many folds for Cross Validation 56 | --tr_st: Training strategy of GANs.
57 |         same: Train generator and discriminator at the same time 58 |         turns: Alternate between training the generator and the discriminator from one epoch to the next 59 |         idle: Similar to turns, but wait for more than one turn (the user can choose how many, see --id_e) 60 | --id_e: If training strategy is idle, for how many epochs 61 | --exp: Experiment number for logging purposes 62 | --tp_c: Coefficient of topology loss 63 | --g_c: Coefficient of adversarial loss 64 | --i_c: Coefficient of identity loss 65 | --kl_c: Coefficient of KL loss 66 | --decay: Weight Decay 67 | 68 | You can run the following command to see a short description of each parameter (the default values are set in the argument parser at the top of EvoGraphNet.py). 69 | 70 | ```bash 71 | python EvoGraphNet.py --help 72 | ``` 73 | ## Example Results 74 | 75 | When given the brain connections data at t0, EvoGraphNet.py will produce two matrices showing brain connections at t1 and t2. In this example, our matrices are 35 x 35. 76 | 77 | ![Visualization](Visualization.png) 78 | 79 | # YouTube videos to install and run the code and understand how EvoGraphNet works 80 | 81 | To install and run EvoGraphNet, check the following YouTube video: 82 | https://youtu.be/eTUeQ15FeRc 83 | 84 | To learn how EvoGraphNet works, check the following YouTube video: 85 | https://youtu.be/aT---t2OBO0 86 | 87 | # Please cite the following paper when using EvoGraphNet: 88 | 89 | ```bibtex 90 | @inproceedings{neblikaplanrekik2020, 91 | title={Deep EvoGraphNet Architecture For Time-Dependent Brain Graph Data Synthesis From a Single Timepoint}, 92 | author={Nebli, Ahmed and Kaplan, Ugur Ali and Rekik, Islem}, 93 | booktitle={International Workshop on PRedictive Intelligence In MEdicine}, 94 | year={2020}, 95 | organization={Springer} 96 | } 97 | ``` 98 | 99 | # EvoGraphNet on arXiv 100 | 101 | Link: https://arxiv.org/abs/2009.13217 102 | 103 | # License 104 | Our code is released under the MIT License (see LICENSE file for details).
105 | -------------------------------------------------------------------------------- /code/data_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import math 6 | import itertools 7 | import copy 8 | import pickle 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU 14 | from torch.autograd import Variable 15 | from torch.distributions import normal 16 | 17 | from sklearn import preprocessing 18 | from sklearn.preprocessing import MinMaxScaler 19 | from sklearn.model_selection import KFold 20 | 21 | from torch_geometric.data import Data, InMemoryDataset, DataLoader 22 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool 23 | from torch_geometric.utils import get_laplacian, to_dense_adj 24 | 25 | import matplotlib.pyplot as plt 26 | 27 | 28 | class MRDataset(InMemoryDataset): 29 | 30 | def __init__(self, root, src, dest, h, connectomes=1, subs=1000, transform=None, pre_transform=None): 31 | 32 | """ 33 | src: Input to the model 34 | dest: Target output of the model 35 | h: Load LH or RH data 36 | subs: Maximum number of subjects 37 | 38 | Note: Since we do not reprocess the data if it is already processed, processed files should be 39 | deleted if there is any change in the data we are reading. 
40 | """ 41 | 42 | self.src, self.dest, self.h, self.subs, self.connectomes = src, dest, h, subs, connectomes 43 | super(MRDataset, self).__init__(root, transform, pre_transform) 44 | self.data, self.slices = torch.load(self.processed_paths[0]) 45 | 46 | def data_read(self, h="lh", nbr_of_subs=1000, connectomes=1): 47 | 48 | """ 49 | Takes the (maximum) number of subjects and hemisphere we are working on 50 | as arguments, returns t0, t1, t2's of the connectomes for each subject 51 | in a single torch.FloatTensor. 52 | """ 53 | 54 | subs = None # Subjects 55 | 56 | data_path = "../data" 57 | 58 | for i in range(1, nbr_of_subs): 59 | s = data_path + "/cortical." + h.lower() + ".ShapeConnectivityTensor_OAS2_" 60 | if i < 10: 61 | s += "0" 62 | s += "00" + str(i) + "_" 63 | 64 | for mr in ["MR1", "MR2"]: 65 | try: # Sometimes subject we are looking for does not exist 66 | t0 = np.loadtxt(s + mr + "_t0.txt") 67 | t1 = np.loadtxt(s + mr + "_t1.txt") 68 | t2 = np.loadtxt(s + mr + "_t2.txt") 69 | except: 70 | continue 71 | 72 | # Read the connectomes at t0, t1 and t2, then stack them 73 | read_limit = (connectomes * 35) 74 | t_stacked = np.vstack((t0[:read_limit, :], t1[:read_limit, :], t2[:read_limit, :])) 75 | tsr = t_stacked.reshape(3, connectomes * 35, 35) 76 | 77 | if subs is None: # If first subject 78 | subs = tsr 79 | else: 80 | subs = np.vstack((subs, tsr)) 81 | 82 | # Then, reshape to match the shape of the model's expected input shape 83 | # final_views should be a torch tensor or Pytorch Geometric complains 84 | final_views = torch.tensor(np.moveaxis(subs.reshape(-1, 3, (connectomes * 35), 35), 1, -1), dtype=torch.float) 85 | 86 | return final_views 87 | 88 | @property 89 | def processed_file_names(self): 90 | return [ 91 | "data_" + str(self.connectomes) + "_" + self.h.lower() + "_" + str(self.subs) + "_" + str(self.src) + str( 92 | self.dest) + ".pt"] 93 | 94 | def process(self): 95 | 96 | """ 97 | Prepares the data for PyTorch Geometric. 
98 | """ 99 | 100 | unprocessed = self.data_read(self.h, self.subs) 101 | num_samples, timestamps = unprocessed.shape[0], unprocessed.shape[-1] 102 | assert 0 <= self.dest <= timestamps 103 | assert 0 <= self.src <= timestamps 104 | 105 | # Turn the data into PyTorch Geometric Graphs 106 | data_list = list() 107 | 108 | for sample in range(num_samples): 109 | x = unprocessed[sample, :, :, self.src] 110 | y = unprocessed[sample, :, :, self.dest] 111 | 112 | edge_index, edge_attr, rows, cols = create_edge_index_attribute(x) 113 | y_edge_index, y_edge_attr, _, _ = create_edge_index_attribute(y) 114 | 115 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 116 | y=y, y_edge_index=y_edge_index, y_edge_attr=y_edge_attr) 117 | 118 | data.num_nodes = rows 119 | data_list.append(data) 120 | 121 | if self.pre_filter is not None: 122 | data_list = [data for data in data_list if self.pre_filter(data)] 123 | 124 | if self.pre_transform is not None: 125 | data_list = [self.pre_transform(data) for data in data_list] 126 | 127 | data, slices = self.collate(data_list) 128 | torch.save((data, slices), self.processed_paths[0]) 129 | 130 | 131 | class MRDataset2(InMemoryDataset): 132 | 133 | def __init__(self, root, h, connectomes=1, subs=1000, transform=None, pre_transform=None): 134 | 135 | """ 136 | src: Input to the model 137 | dest: Target output of the model 138 | h: Load LH or RH data 139 | subs: Maximum number of subjects 140 | 141 | Note: Since we do not reprocess the data if it is already processed, processed files should be 142 | deleted if there is any change in the data we are reading. 
143 | """ 144 | 145 | self.h, self.subs, self.connectomes = h, subs, connectomes 146 | super(MRDataset2, self).__init__(root, transform, pre_transform) 147 | self.data, self.slices = torch.load(self.processed_paths[0]) 148 | 149 | def data_read(self, h="lh", nbr_of_subs=1000, connectomes=1): 150 | 151 | """ 152 | Takes the (maximum) number of subjects and hemisphere we are working on 153 | as arguments, returns t0, t1, t2's of the connectomes for each subject 154 | in a single torch.FloatTensor. 155 | """ 156 | 157 | subs = None # Subjects 158 | 159 | data_path = "../data" 160 | 161 | for i in range(1, nbr_of_subs): 162 | s = data_path + "/cortical." + h.lower() + ".ShapeConnectivityTensor_OAS2_" 163 | if i < 10: 164 | s += "0" 165 | s += "00" + str(i) + "_" 166 | 167 | for mr in ["MR1", "MR2"]: 168 | try: # Sometimes subject we are looking for does not exist 169 | t0 = np.loadtxt(s + mr + "_t0.txt") 170 | t1 = np.loadtxt(s + mr + "_t1.txt") 171 | t2 = np.loadtxt(s + mr + "_t2.txt") 172 | except: 173 | continue 174 | 175 | # Read the connectomes at t0, t1 and t2, then stack them 176 | read_limit = (connectomes * 35) 177 | t_stacked = np.vstack((t0[:read_limit, :], t1[:read_limit, :], t2[:read_limit, :])) 178 | tsr = t_stacked.reshape(3, connectomes * 35, 35) 179 | 180 | if subs is None: # If first subject 181 | subs = tsr 182 | else: 183 | subs = np.vstack((subs, tsr)) 184 | 185 | # Then, reshape to match the shape of the model's expected input shape 186 | # final_views should be a torch tensor or Pytorch Geometric complains 187 | final_views = torch.tensor(np.moveaxis(subs.reshape(-1, 3, (connectomes * 35), 35), 1, -1), dtype=torch.float) 188 | 189 | return final_views 190 | 191 | @property 192 | def processed_file_names(self): 193 | return [ 194 | "2data_" + str(self.connectomes) + "_" + self.h.lower() + "_" + str(self.subs) + "_" + ".pt"] 195 | 196 | def process(self): 197 | 198 | """ 199 | Prepares the data for PyTorch Geometric. 
200 | """ 201 | 202 | unprocessed = self.data_read(self.h, self.subs) 203 | num_samples, timestamps = unprocessed.shape[0], unprocessed.shape[-1] 204 | 205 | # Turn the data into PyTorch Geometric Graphs 206 | data_list = list() 207 | 208 | for sample in range(num_samples): 209 | x = unprocessed[sample, :, :, 0] 210 | y = unprocessed[sample, :, :, 1] 211 | y2 = unprocessed[sample, :, :, 2] 212 | 213 | edge_index, edge_attr, rows, cols = create_edge_index_attribute(x) 214 | y_edge_index, y_edge_attr, _, _ = create_edge_index_attribute(y) 215 | y2_edge_index, y2_edge_attr, _, _ = create_edge_index_attribute(y2) 216 | y_distr = normal.Normal(y.mean(dim=1), y.std(dim=1)) 217 | y2_distr = normal.Normal(y2.mean(dim=1), y2.std(dim=1)) 218 | y_lap_ei, y_lap_ea = get_laplacian(y_edge_index, y_edge_attr) 219 | y2_lap_ei, y2_lap_ea = get_laplacian(y2_edge_index, y2_edge_attr) 220 | y_lap = to_dense_adj(y_lap_ei, edge_attr=y_lap_ea) 221 | y2_lap = to_dense_adj(y2_lap_ei, edge_attr=y2_lap_ea) 222 | 223 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 224 | y=y, y_edge_index=y_edge_index, y_edge_attr=y_edge_attr, y_distr=y_distr, 225 | y2=y2, y2_edge_index=y2_edge_index, y2_edge_attr=y2_edge_attr, y2_distr=y2_distr, 226 | y_lap=y_lap, y2_lap=y2_lap) 227 | 228 | data.num_nodes = rows 229 | data_list.append(data) 230 | 231 | if self.pre_filter is not None: 232 | data_list = [data for data in data_list if self.pre_filter(data)] 233 | 234 | if self.pre_transform is not None: 235 | data_list = [self.pre_transform(data) for data in data_list] 236 | 237 | data, slices = self.collate(data_list) 238 | torch.save((data, slices), self.processed_paths[0]) 239 | 240 | 241 | def create_edge_index_attribute(adj_matrix): 242 | """ 243 | Given an adjacency matrix, this function creates the edge index and edge attribute matrix 244 | suitable to graph representation in PyTorch Geometric. 
245 | """ 246 | 247 | rows, cols = adj_matrix.shape[0], adj_matrix.shape[1] 248 | edge_index = torch.zeros((2, rows * cols), dtype=torch.long) 249 | edge_attr = torch.zeros((rows * cols, 1), dtype=torch.float) 250 | counter = 0 251 | 252 | for src, attrs in enumerate(adj_matrix): 253 | for dest, attr in enumerate(attrs): 254 | edge_index[0][counter], edge_index[1][counter] = src, dest 255 | edge_attr[counter] = attr 256 | counter += 1 257 | 258 | return edge_index, edge_attr, rows, cols 259 | 260 | 261 | def swap(data): 262 | # Swaps the x & y values of the given graph 263 | edge_i, edge_attr, _, _ = create_edge_index_attribute(data.y) 264 | data_s = Data(x=data.y, edge_index=edge_i, edge_attr=edge_attr, y=data.x) 265 | return data_s 266 | 267 | 268 | def cross_val_indices(folds, num_samples, new=False): 269 | """ 270 | Takes the number of inputs and number of folds. 271 | Determines indices to go into validation split in each turn. 272 | Saves the indices on a file for experimental reproducibility and does not overwrite 273 | the already determined indices unless new=True. 
274 | """ 275 | 276 | kf = KFold(n_splits=folds, shuffle=True) 277 | train_indices = list() 278 | val_indices = list() 279 | 280 | try: 281 | if new == True: 282 | raise IOError 283 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_train", "rb") as f: 284 | train_indices = pickle.load(f) 285 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_val", "rb") as f: 286 | val_indices = pickle.load(f) 287 | except IOError: 288 | for tr_index, val_index in kf.split(np.zeros((num_samples, 1))): 289 | train_indices.append(tr_index) 290 | val_indices.append(val_index) 291 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_train", "wb") as f: 292 | pickle.dump(train_indices, f) 293 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_val", "wb") as f: 294 | pickle.dump(val_indices, f) 295 | 296 | return train_indices, val_indices 297 | -------------------------------------------------------------------------------- /code/EvoGraphNet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import math 6 | import itertools 7 | import copy 8 | import pickle 9 | from sys import exit 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU 15 | from torch.autograd import Variable 16 | from torch.distributions import normal, kl 17 | 18 | from sklearn import preprocessing 19 | from sklearn.preprocessing import MinMaxScaler 20 | from sklearn.model_selection import KFold 21 | 22 | from torch_geometric.data import Data, InMemoryDataset, DataLoader 23 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool 24 | from torch_geometric.utils import get_laplacian, to_dense_adj 25 | 26 | import matplotlib.pyplot as plt 27 | 28 | from data_utils import MRDataset, 
create_edge_index_attribute, swap, cross_val_indices, MRDataset2 29 | from model import Generator, Discriminator 30 | from plot import plot, plot_matrix 31 | 32 | torch.manual_seed(0) # To get the same results across experiments 33 | 34 | if torch.cuda.is_available(): 35 | device = torch.device('cuda') 36 | print('running on GPU') 37 | else: 38 | device = torch.device("cpu") 39 | print('running on CPU') 40 | 41 | # Parser 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--lr_g', type=float, default=0.01, help='Generator learning rate') 44 | parser.add_argument('--lr_d', type=float, default=0.0002, help='Discriminator learning rate') 45 | parser.add_argument('--loss', type=str, default='BCE', help='Which loss to use for training', 46 | choices=['BCE', 'LS']) 47 | parser.add_argument('--batch', type=int, default=1, help='Batch Size') 48 | parser.add_argument('--epoch', type=int, default=500, help='How many epochs to train') 49 | parser.add_argument('--folds', type=int, default=3, help='How many folds for CV') 50 | parser.add_argument('--tr_st', type=str, default='same', help='Training strategy', 51 | choices=['same', 'turns', 'idle']) 52 | parser.add_argument('--id_e', type=int, default=2, help='If training strategy is idle, for how many epochs') 53 | parser.add_argument('--exp', type=int, default=0, help='Which experiment are you running') 54 | parser.add_argument('--tp_c', type=float, default=0.0, help='Coefficient of topology loss') 55 | parser.add_argument('--g_c', type=float, default=2.0, help='Coefficient of adversarial loss') 56 | parser.add_argument('--i_c', type=float, default=2.0, help='Coefficient of identity loss') 57 | parser.add_argument('--kl_c', type=float, default=0.001, help='Coefficient of KL loss') 58 | parser.add_argument('--decay', type=float, default=0.0, help='Weight Decay') 59 | opt = parser.parse_args() 60 | 61 | # Datasets 62 | 63 | h_data = MRDataset2("../data", "lh", subs=989) 64 | 65 | # Parameters 66 | 67 | batch_size = 
opt.batch  # NOTE(review): tail of line 67 (presumably `batch_size = opt.batch`), which begins before this chunk
lr_G = opt.lr_g
lr_D = opt.lr_d
num_epochs = opt.epoch
folds = opt.folds

connectomes = 1
train_generator = 1

# Coefficients of the generator objective
i_coeff = opt.i_c    # identity (L1) term
g_coeff = opt.g_c    # adversarial term
kl_coeff = opt.kl_c  # KL term between row-wise Normal fits of prediction and target
tp_coeff = opt.tp_c  # topology term (MSE between row sums / node strengths)

# The idle-epoch count is only meaningful for the "idle" training strategy.
if opt.tr_st != 'idle':
    opt.id_e = 0

# Training

loss_dict = {"BCE": torch.nn.BCELoss().to(device),
             "LS": torch.nn.MSELoss().to(device)}

adversarial_loss = loss_dict[opt.loss.upper()]
identity_loss = torch.nn.L1Loss().to(device)  # Will be used in training
msel = torch.nn.MSELoss().to(device)
mael = torch.nn.L1Loss().to(device)  # Not to be used in training (Measure generator success)
counter_g, counter_d = 0, 0
tp = torch.nn.MSELoss().to(device)  # Used for node strength

train_ind, val_ind = cross_val_indices(folds, len(h_data))

# Per-batch quantities summed over an epoch for each GAN stage: discriminator
# real / fake / total loss, generator total loss, MSE / MAE against the ground
# truth, and the KL, topology and adversarial generator terms.
_METRICS = ("real", "fake", "d", "g", "mse", "mae", "kl", "tp", "gan")

# Epoch averages kept as per-fold curves, keyed by (split, stage, metric).
_CURVE_METRICS = ("real", "fake", "mse", "mae", "kl", "tp", "gan")

# (y-axis label, plot title prefix, curve key), in the order the plots are produced.
_PLOTS = [
    ("BCE", "DiscriminatorRealLossTrainSet", ("train", 1, "real")),
    ("BCE", "DiscriminatorRealLossValSet", ("val", 1, "real")),
    ("BCE", "DiscriminatorFakeLossTrainSet", ("train", 1, "fake")),
    ("BCE", "DiscriminatorFakeLossValSet", ("val", 1, "fake")),
    ("MSE", "GeneratorMSELossTrainSet", ("train", 1, "mse")),
    ("MSE", "GeneratorMSELossValSet", ("val", 1, "mse")),
    ("MAE", "GeneratorMAELossTrainSet", ("train", 1, "mae")),
    ("MAE", "GeneratorMAELossValSet", ("val", 1, "mae")),
    ("BCE", "Discriminator2RealLossTrainSet", ("train", 2, "real")),
    ("BCE", "Discriminator2RealLossValSet", ("val", 2, "real")),
    ("BCE", "Discriminator2FakeLossTrainSet", ("train", 2, "fake")),
    ("BCE", "Discriminator2FakeLossValSet", ("val", 2, "fake")),
    ("MSE", "Generator2MSELossTrainSet", ("train", 2, "mse")),
    ("MSE", "Generator2MSELossValSet", ("val", 2, "mse")),
    ("MAE", "Generator2MAELossTrainSet", ("train", 2, "mae")),
    ("MAE", "Generator2MAELossValSet", ("val", 2, "mae")),
    ("KL Loss", "KL_Loss_1_TrainSet", ("train", 1, "kl")),
    ("KL Loss", "KL_Loss_1_ValSet", ("val", 1, "kl")),
    ("KL Loss", "KL_Loss_2_TrainSet", ("train", 2, "kl")),
    ("KL Loss", "KL_Loss_2_ValSet", ("val", 2, "kl")),
    ("TP Loss", "TP_Loss_1_TrainSet", ("train", 1, "tp")),
    ("TP Loss", "TP_Loss_1_ValSet", ("val", 1, "tp")),
    ("TP Loss", "TP_Loss_2_TrainSet", ("train", 2, "tp")),
    ("TP Loss", "TP_Loss_2_ValSet", ("val", 2, "tp")),
    ("BCE", "GAN_Loss_1_TrainSet", ("train", 1, "gan")),
    ("BCE", "GAN_Loss_1_ValSet", ("val", 1, "gan")),
    ("BCE", "GAN_Loss_2_TrainSet", ("train", 2, "gan")),
    ("BCE", "GAN_Loss_2_ValSet", ("val", 2, "gan")),
]

# (file prefix under ../losses, curve key) for the curves pickled at the end.
# Each file name now matches the curve it contains (K1/K2, TP1/TP2 and
# GAN1/GAN2 train/val files were previously crossed).
_SAVED_CURVES = [
    ("G_TrainLoss", ("train", 1, "mae")),
    ("G_ValLoss", ("val", 1, "mae")),
    ("D_TrainRealLoss", ("train", 1, "real")),
    ("D_TrainFakeLoss", ("train", 1, "fake")),
    ("D_ValRealLoss", ("val", 1, "real")),
    ("D_ValFakeLoss", ("val", 1, "fake")),
    ("G2_TrainLoss", ("train", 2, "mae")),
    ("G2_ValLoss", ("val", 2, "mae")),
    ("D2_TrainRealLoss", ("train", 2, "real")),
    ("D2_TrainFakeLoss", ("train", 2, "fake")),
    ("D2_ValRealLoss", ("val", 2, "real")),
    ("D2_ValFakeLoss", ("val", 2, "fake")),
    ("K1_TrainLoss", ("train", 1, "kl")),
    ("K1_ValLoss", ("val", 1, "kl")),
    ("K2_TrainLoss", ("train", 2, "kl")),
    ("K2_ValLoss", ("val", 2, "kl")),
    ("TP1_TrainLoss", ("train", 1, "tp")),
    ("TP1_ValLoss", ("val", 1, "tp")),
    ("TP2_TrainLoss", ("train", 2, "tp")),
    ("TP2_ValLoss", ("val", 2, "tp")),
    ("GAN1_TrainLoss", ("train", 1, "gan")),
    ("GAN1_ValLoss", ("val", 1, "gan")),
    ("GAN2_TrainLoss", ("train", 2, "gan")),
    ("GAN2_ValLoss", ("val", 2, "gan")),
]


def _run_stage(gen, disc, src, real_tgt, y, real, fake, sums,
               opt_g=None, opt_d=None, update_g=False, update_d=False):
    """Run one GAN stage on a batch and add its batch losses to ``sums``.

    The same code serves both stages (t0 -> t1 and t1 -> t2) and both
    training and validation, so every variant computes the same objective.

    Args:
        gen, disc: generator / discriminator of this stage.
        src: input graph for ``gen`` (the real batch for stage 1, the
            generated t1 graph for stage 2). It is also the conditioning
            graph passed to ``disc``.
        real_tgt: ground-truth target graph as a ``Data`` object.
        y: ground-truth target features (the ``x`` of ``real_tgt``).
        real, fake: label tensors already sliced to the batch's row count.
        sums: dict keyed by ``_METRICS``. Each entry is incremented by the
            batch value of that metric.
        opt_g, opt_d: optimizers. They are only stepped when the matching
            update flag is set; pass ``None`` for validation.
        update_g, update_d: whether to take a generator / discriminator step.

    Returns:
        The generated target graph. Its ``x`` is the generator output computed
        before the generator step and is not detached here.
    """
    # Discriminator: real target vs. a detached generated target.
    if opt_d is not None:
        opt_d.zero_grad()
    fake_x = gen(src).detach()
    edge_i, edge_a, _, _ = create_edge_index_attribute(fake_x)
    fake_data = Data(x=fake_x, edge_attr=edge_a, edge_index=edge_i).to(device)

    real_loss = adversarial_loss(disc(real_tgt, src), real)
    fake_loss = adversarial_loss(disc(fake_data, src), fake)
    loss_D = torch.mean(real_loss + fake_loss) / 2
    sums["real"] += real_loss.item()
    sums["fake"] += fake_loss.item()
    sums["d"] += loss_D.item()
    if update_d:
        loss_D.backward(retain_graph=True)
        opt_d.step()

    # Generator: adversarial + KL + topology + identity terms.
    if opt_g is not None:
        opt_g.zero_grad()
    fake_data.x = gen(src)
    gan_loss = torch.mean(adversarial_loss(disc(fake_data, src), real))
    # KL between Normal distributions fitted to each row (mean/std over dim 1).
    kl_loss = kl.kl_divergence(normal.Normal(fake_data.x.mean(dim=1), fake_data.x.std(dim=1)),
                               normal.Normal(y.mean(dim=1), y.std(dim=1))).sum()
    # Topology: match the row sums (node strengths).
    tp_loss = tp(fake_data.x.sum(dim=-1), y.sum(dim=-1))
    # Identity: this stage's own generator fed the real target should reproduce it.
    loss_G = (i_coeff * identity_loss(gen(real_tgt), y) + g_coeff * gan_loss
              + kl_coeff * kl_loss + tp_coeff * tp_loss)
    sums["gan"] += gan_loss.item()
    sums["kl"] += kl_loss.item()
    sums["tp"] += tp_loss.item()
    sums["g"] += loss_G.item()
    if update_g:
        loss_G.backward(retain_graph=True)
        opt_g.step()

    # Success metrics of the (possibly just updated) generator; no graph needed.
    with torch.no_grad():
        pred = gen(src)
        sums["mse"] += msel(pred, y).item()
        sums["mae"] += mael(pred, y).item()

    return fake_data


def _report_line(split_tag, suffix, avg):
    """Format one epoch summary line, e.g. ``[Train]: D2 Loss: ...``."""
    return (f'[{split_tag}]: D{suffix} Loss: {avg["d"]:.5f}, G{suffix} Loss: {avg["g"]:.5f} '
            f'R{suffix} Loss: {avg["real"]:.5f}, F{suffix} Loss: {avg["fake"]:.5f}, '
            f'MSE: {avg["mse"]:.5f}, MAE: {avg["mae"]:.5f}')


# Curves of every fold. The first fold's list is stored as-is; later folds
# are stacked below it with np.vstack (one row per fold).
fold_curves = {}

# Cross Validation
for fold in range(folds):
    train_set, val_set = h_data[list(train_ind[fold])], h_data[list(val_ind[fold])]
    h_data_train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    h_data_test_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    val_step = len(h_data_test_loader)

    # Row count of a full batch; the label tensors are sliced for smaller batches.
    data_size = next(iter(h_data_train_loader)).x.size(0)

    # Create generators and discriminators
    generator = Generator().to(device)
    generator2 = Generator().to(device)
    discriminator = Discriminator().to(device)
    discriminator2 = Discriminator().to(device)

    optimizer_G = torch.optim.AdamW(generator.parameters(), lr=lr_G, betas=(0.5, 0.999), weight_decay=opt.decay)
    optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=lr_D, betas=(0.5, 0.999), weight_decay=opt.decay)
    optimizer_G2 = torch.optim.AdamW(generator2.parameters(), lr=lr_G, betas=(0.5, 0.999), weight_decay=opt.decay)
    optimizer_D2 = torch.optim.AdamW(discriminator2.parameters(), lr=lr_D, betas=(0.5, 0.999), weight_decay=opt.decay)

    total_step = len(h_data_train_loader)
    real_label = torch.ones((data_size, 1)).to(device)
    fake_label = torch.zeros((data_size, 1)).to(device)

    # Will be used for reporting: one list of epoch averages per curve key.
    curves = {(split, stage, metric): []
              for split in ("train", "val")
              for stage in (1, 2)
              for metric in _CURVE_METRICS}

    for epoch in range(num_epochs):
        # Index 0 = stage 1 (t0 -> t1), index 1 = stage 2 (t1 -> t2).
        train_sums = [dict.fromkeys(_METRICS, 0.0) for _ in range(2)]
        val_sums = [dict.fromkeys(_METRICS, 0.0) for _ in range(2)]

        # Which networks take optimizer steps this epoch, by opt.tr_st:
        #   "same":  all of them, every epoch;
        #   "turns": generators on even epochs, discriminators on odd epochs;
        #   "idle":  generators while counter_g < id_e, discriminators while counter_d >= id_e.
        update_g = (opt.tr_st == "same"
                    or (opt.tr_st == "turns" and epoch % 2 == 0)
                    or (opt.tr_st == "idle" and counter_g < opt.id_e))
        update_d = (opt.tr_st == "same"
                    or (opt.tr_st == "turns" and epoch % 2 == 1)
                    or (opt.tr_st == "idle" and counter_d >= opt.id_e))

        # Train
        generator.train()
        discriminator.train()
        generator2.train()
        discriminator2.train()
        for data in h_data_train_loader:
            data = data.to(device)
            n_rows = data.x.size(0)
            real, fake = real_label[:n_rows, :], fake_label[:n_rows, :]

            # Stage 1: real source -> t1
            swapped_data = Data(x=data.y, edge_attr=data.y_edge_attr, edge_index=data.y_edge_index).to(device)
            fake_data = _run_stage(generator, discriminator, data, swapped_data, data.y, real, fake,
                                   train_sums[0], optimizer_G, optimizer_D, update_g, update_d)

            # Stage 2: generated t1 -> t2. Detach so stage 2 never backpropagates into generator 1.
            fake_data.x = fake_data.x.detach()
            swapped_data2 = Data(x=data.y2, edge_attr=data.y2_edge_attr, edge_index=data.y2_edge_index).to(device)
            _run_stage(generator2, discriminator2, fake_data, swapped_data2, data.y2, real, fake,
                       train_sums[1], optimizer_G2, optimizer_D2, update_g, update_d)

        # Validate (no parameter updates, so no autograd graph is needed)
        generator.eval()
        discriminator.eval()
        generator2.eval()
        discriminator2.eval()
        with torch.no_grad():
            for data in h_data_test_loader:
                data = data.to(device)
                n_rows = data.x.size(0)
                real, fake = real_label[:n_rows, :], fake_label[:n_rows, :]

                swapped_data = Data(x=data.y, edge_attr=data.y_edge_attr, edge_index=data.y_edge_index).to(device)
                fake_data = _run_stage(generator, discriminator, data, swapped_data, data.y, real, fake,
                                       val_sums[0])

                fake_data.x = fake_data.x.detach()
                swapped_data2 = Data(x=data.y2, edge_attr=data.y2_edge_attr, edge_index=data.y2_edge_index).to(device)
                _run_stage(generator2, discriminator2, fake_data, swapped_data2, data.y2, real, fake,
                           val_sums[1])

        if opt.tr_st == 'idle':
            counter_g += 1
            counter_d += 1
            if counter_g == 2 * opt.id_e:
                counter_g = 0
                counter_d = 0

        train_avg = [{key: value / total_step for key, value in sums.items()} for sums in train_sums]
        val_avg = [{key: value / val_step for key, value in sums.items()} for sums in val_sums]

        print(f'Epoch [{epoch + 1}/{num_epochs}]')
        print(_report_line('Train', '', train_avg[0]))
        print(_report_line('Val', '', val_avg[0]))
        print(_report_line('Train', '2', train_avg[1]))
        print(_report_line('Val', '2', val_avg[1]))

        for split, averages in (("train", train_avg), ("val", val_avg)):
            for stage, avg in zip((1, 2), averages):
                for metric in _CURVE_METRICS:
                    curves[(split, stage, metric)].append(avg[metric])

    # Plot losses
    plot_suffix = str(fold) + "_exp" + str(opt.exp)
    for ylabel, title_prefix, key in _PLOTS:
        plot(ylabel, title_prefix + plot_suffix, curves[key])

    # Save the losses of this fold
    for key, curve in curves.items():
        fold_curves[key] = curve if key not in fold_curves else np.vstack([fold_curves[key], curve])

    # Save the models (file names carry the index of the last epoch)
    torch.save(generator.state_dict(), "../weights/generator_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
    torch.save(discriminator.state_dict(), "../weights/discriminator_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
    torch.save(generator2.state_dict(), "../weights/generator2_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
    torch.save(discriminator2.state_dict(), "../weights/discriminator2_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))

    del generator
    del discriminator

    del generator2
    del discriminator2

# Save losses
saved_losses = [(name, fold_curves[key]) for name, key in _SAVED_CURVES]
# Element-wise sum of both generators' MAE curves. np.asarray keeps this an
# addition (not list concatenation) when there is only one fold.
saved_losses.append(("GenTotal_Train", np.asarray(fold_curves[("train", 1, "mae")])
                     + np.asarray(fold_curves[("train", 2, "mae")])))
saved_losses.append(("GenTotal_Val", np.asarray(fold_curves[("val", 1, "mae")])
                     + np.asarray(fold_curves[("val", 2, "mae")])))
for name, losses in saved_losses:
    with open("../losses/" + name + "_exp_" + str(opt.exp), "wb") as f:
        pickle.dump(losses, f)

print(f"Training Complete for experiment {opt.exp}!")

# --------------------------------------------------------------------------------