├── losses
│   └── Placeholder
├── plots
│   └── Placeholder
├── weights
│   └── Placeholder
├── fig1.png
├── Visualization.png
├── data
│   ├── create_data.py
│   └── .ipynb_checkpoints
│       └── Untitled-checkpoint.ipynb
├── code
│   ├── plot.py
│   ├── model.py
│   ├── data_utils.py
│   └── EvoGraphNet.py
└── README.md
/losses/Placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/plots/Placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/weights/Placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/EvoGraphNet/HEAD/fig1.png
--------------------------------------------------------------------------------
/Visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiralab/EvoGraphNet/HEAD/Visualization.png
--------------------------------------------------------------------------------
/data/create_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | mean, std = np.random.rand(), np.random.rand()
4 |
5 | for i in range(1, 114):
6 |
7 | # Create adjacency matrices
8 |
9 |     t0 = np.abs(np.random.normal(mean, std, (35, 35))) % 1.0  # fold edge weights into [0, 1)
10 |     mean_s = mean + np.random.rand() % 0.1  # drift the statistics slightly for t1
11 |     std_s = std + np.random.rand() % 0.1
12 |     t1 = np.abs(np.random.normal(mean_s, std_s, (35, 35))) % 1.0
13 |     mean_s = mean + np.random.rand() % 0.1  # ... and again for t2
14 |     std_s = std + np.random.rand() % 0.1
15 |     t2 = np.abs(np.random.normal(mean_s, std_s, (35, 35))) % 1.0
16 |
17 | # Make them symmetric
18 |
19 | t0 = (t0 + t0.T)/2
20 | t1 = (t1 + t1.T)/2
21 | t2 = (t2 + t2.T)/2
22 |
23 | # Clean the diagonals
24 | t0[np.diag_indices_from(t0)] = 0
25 | t1[np.diag_indices_from(t1)] = 0
26 | t2[np.diag_indices_from(t2)] = 0
27 |
28 | # Save them
29 | s = "cortical.lh.ShapeConnectivityTensor_OAS2_"
30 |     # Zero-pad the subject id to four digits (e.g. OAS2_0001), matching data_utils.py
31 |     s += str(i).zfill(4)
32 |     s += "_MR1"
33 |
34 | t0_s = s + "_t0.txt"
35 | t1_s = s + "_t1.txt"
36 | t2_s = s + "_t2.txt"
37 |
38 | np.savetxt(t0_s, t0)
39 | np.savetxt(t1_s, t1)
40 | np.savetxt(t2_s, t2)
41 |
--------------------------------------------------------------------------------
/code/plot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import math
6 | import itertools
7 | import copy
8 | import pickle
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU
14 | from torch.autograd import Variable
15 |
16 | from sklearn import preprocessing
17 | from sklearn.preprocessing import MinMaxScaler
18 | from sklearn.model_selection import KFold
19 |
20 | from torch_geometric.data import Data, InMemoryDataset, DataLoader
21 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool
22 |
23 | import matplotlib.pyplot as plt
24 |
25 |
26 | def plot(loss, title, losses):
27 | fig = plt.figure()
28 | plt.plot(losses)
29 | plt.xlabel("# epoch")
30 | plt.ylabel(loss)
31 | plt.title(title)
32 | plt.savefig('../plots/' + title + '.png')
33 | plt.close()
34 |
35 |
36 | def plot_matrix(out, fold, sample, epoch, strategy):
37 |     fig = plt.figure()
38 |     # Plot the magnitude of the generated matrix with a matching colorbar
39 |     plt.imshow(abs(out))
40 |     plt.colorbar()
41 | title = "Generator Output, Epoch = " + str(epoch) + " Fold = " + str(fold) + " Strategy = " + strategy
42 | plt.title(title)
43 | plt.savefig('../plots/' + str(fold) + 'Gen_' + str(sample) + '_' + str(epoch) + '.png')
44 | plt.close()
45 |
46 |
47 |
--------------------------------------------------------------------------------
/data/.ipynb_checkpoints/Untitled-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 46,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import matplotlib.pyplot as plt"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 47,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "t0 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t0.txt\")\n",
20 | "t1 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t1.txt\")\n",
21 | "t2 = np.loadtxt(\"cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t2.txt\")"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 48,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "def plot_matrix(m):\n",
31 | " plt.matshow(m)\n",
32 | " plt.colorbar()\n",
33 | " plt.show()"
34 | ]
35 | }
36 | ],
37 | "metadata": {
38 | "kernelspec": {
39 | "display_name": "Python 3",
40 | "language": "python",
41 | "name": "python3"
42 | },
43 | "language_info": {
44 | "codemirror_mode": {
45 | "name": "ipython",
46 | "version": 3
47 | },
48 | "file_extension": ".py",
49 | "mimetype": "text/x-python",
50 | "name": "python",
51 | "nbconvert_exporter": "python",
52 | "pygments_lexer": "ipython3",
53 | "version": "3.7.7"
54 | }
55 | },
56 | "nbformat": 4,
57 | "nbformat_minor": 4
58 | }
59 |
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import math
6 | import itertools
7 | import copy
8 | import pickle
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU
14 | from torch.autograd import Variable
15 |
16 | from sklearn import preprocessing
17 | from sklearn.preprocessing import MinMaxScaler
18 | from sklearn.model_selection import KFold
19 |
20 | from torch_geometric.data import Data, InMemoryDataset, DataLoader
21 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool
22 |
23 | import matplotlib.pyplot as plt
24 |
25 |
26 | class Generator(nn.Module):
27 | def __init__(self):
28 | super(Generator, self).__init__()
29 |
30 | lin = Sequential(Linear(1, 1225), ReLU())
31 | self.conv1 = NNConv(35, 35, lin, aggr='mean', root_weight=True, bias=True)
32 | self.conv11 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
33 |
34 | lin = Sequential(Linear(1, 35), ReLU())
35 | self.conv2 = NNConv(35, 1, lin, aggr='mean', root_weight=True, bias=True)
36 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
37 |
38 | lin = Sequential(Linear(1, 35), ReLU())
39 | self.conv3 = NNConv(1, 35, lin, aggr='mean', root_weight=True, bias=True)
40 | self.conv33 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
41 |
42 | def forward(self, data):
43 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
44 |
45 | x1 = torch.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
46 | x1 = F.dropout(x1, training=self.training)
47 |         # Correction: symmetrize the intermediate graph and zero its diagonal
48 |         x1 = (x1 + x1.T) / 2.0
49 |         x1.fill_diagonal_(0)
50 | x2 = torch.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
51 | x2 = F.dropout(x2, training=self.training)
52 |
53 | x3 = torch.cat([torch.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1)
54 | x4 = x3[:, 0:35]
55 | x5 = x3[:, 35:70]
56 |
57 | x6 = (x4 + x5) / 2
58 |         # Correction: symmetrize the output graph and zero its diagonal
59 |         x6 = (x6 + x6.T) / 2.0
60 |         x6.fill_diagonal_(0)
61 | return x6
62 |
63 |
64 | class Discriminator(torch.nn.Module):
65 |
66 | def __init__(self):
67 | super(Discriminator, self).__init__()
68 | lin = Sequential(Linear(2, 1225), ReLU())
69 | self.conv1 = NNConv(35, 35, lin, aggr='mean', root_weight=True, bias=True)
70 | self.conv11 = BatchNorm(35, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
71 |
72 | lin = Sequential(Linear(2, 35), ReLU())
73 | self.conv2 = NNConv(35, 1, lin, aggr='mean', root_weight=True, bias=True)
74 | self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)
75 |
76 | def forward(self, data, data_to_translate):
77 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
78 | edge_attr_data_to_translate = data_to_translate.edge_attr
79 |
80 |         edge_attr_data_to_translate_reshaped = edge_attr_data_to_translate.view(1225, 1)  # 1225 = 35 * 35 edges
81 |
82 | gen_input = torch.cat((edge_attr, edge_attr_data_to_translate_reshaped), -1)
83 | x = F.relu(self.conv11(self.conv1(x, edge_index, gen_input)))
84 | x = F.dropout(x, training=self.training)
85 | x = F.relu(self.conv22(self.conv2(x, edge_index, gen_input)))
86 |
87 | return torch.sigmoid(x)
88 |
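89 | 
90 | if __name__ == "__main__":
91 |     # Hypothetical shape check (a sketch, not part of the original training
92 |     # pipeline): push one random 35-node graph through an untrained generator
93 |     # and confirm it returns a 35 x 35 matrix.
94 |     from data_utils import create_edge_index_attribute
95 |     x = torch.rand((35, 35))
96 |     ei, ea, _, _ = create_edge_index_attribute(x)
97 |     g = Generator()
98 |     g.eval()
99 |     with torch.no_grad():
100 |         out = g(Data(x=x, edge_index=ei, edge_attr=ea))
101 |     print(out.shape)  # torch.Size([35, 35])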
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EvoGraphNet
2 | EvoGraphNet for joint prediction of brain connection evolution, coded up in Python by Uğur Ali Kaplan (uguralikaplan@gmail.com) and Ahmed Nebli (mr.ahmednebli@gmail.com).
3 |
4 | This repository provides the official PyTorch implementation of the following paper:
5 |
6 | 
7 |
8 | > **Deep EvoGraphNet Architecture For Time-Dependent Brain Graph Data Synthesis From a Single Timepoint**
9 | > [Ahmed Nebli](https://github.com/ahmednebli)<sup>†,1,2</sup>, [Uğur Ali Kaplan](https://github.com/UgurKap)<sup>†,1</sup>, [Islem Rekik](https://basira-lab.com/)<sup>1</sup>
10 | > <sup>1</sup>BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey
11 | > <sup>2</sup>National School for Computer Science (ENSI), Mannouba, Tunisia
12 | > <sup>†</sup>Equal Contribution
13 | >
14 | > **Abstract:** *Learning how to predict the brain connectome (i.e. graph) development and aging is of paramount importance for charting the future of within-disorder and cross-disorder landscape of brain dysconnectivity evolution. Indeed, predicting the longitudinal (i.e., time-dependent) brain dysconnectivity as it emerges and evolves over time from a single timepoint can help design personalized treatments for disordered patients in a very early stage. Despite its significance, evolution models of the brain graph are largely overlooked in the literature. Here, we propose EvoGraphNet, the first end-to-end geometric deep learning powered graph-generative adversarial network (gGAN) for predicting time-dependent brain graph evolution from a single timepoint. Our EvoGraphNet architecture cascades a set of time-dependent gGANs, where each gGAN communicates its predicted brain graphs at a particular timepoint to train the next gGAN in the cascade at follow-up timepoint. Therefore, we obtain each next predicted timepoint by setting the output of each generator as the input of its successor which enables us to predict a given number of timepoints using only one single timepoint in an end-to-end fashion. At each timepoint, to better align the distribution of the predicted brain graphs with that of the ground-truth graphs, we further integrate an auxiliary Kullback-Leibler divergence loss function. To capture time-dependency between two consecutive observations, we impose an l1 loss to minimize the sparse distance between two serialized brain graphs. A series of benchmarks against variants and ablated versions of our EvoGraphNet showed that we can achieve the lowest brain graph evolution prediction error using a single baseline timepoint.*
15 |
16 | ## Dependencies
17 | * [Python 3.8+](https://www.python.org/)
18 | * [PyTorch 1.5.0+](http://pytorch.org/)
19 | * [PyTorch Geometric 1.4.3+ and Relevant Packages](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)
20 | * [Scikit-learn 0.23.0+](https://scikit-learn.org/stable/)
21 | * [Matplotlib 3.1.3+](https://matplotlib.org/)
22 | * [Numpy 1.18.1+](https://numpy.org/)
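
One possible environment setup (a sketch only; the PyTorch Geometric companion packages such as `torch-scatter` and `torch-sparse` must match your PyTorch and CUDA builds, so follow the installation page linked above):

```bash
pip install numpy scikit-learn matplotlib torch
pip install torch-scatter torch-sparse torch-geometric  # versions per the PyG installation guide
```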
23 |
24 | ## Simulating Time-series data
25 |
26 | To simulate longitudinal brain data, run create_data.py under the "data" directory. It will create 113 random subjects, each with three symmetric 35 x 35 connectivity matrices (t0, t1, t2).
27 |
28 | ```bash
29 | python create_data.py
30 | ```
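
Each subject `i` yields three files, `cortical.lh.ShapeConnectivityTensor_OAS2_<id>_MR1_t{0,1,2}.txt`. As a quick sanity check (a sketch, not part of the repository), you can verify that a generated matrix has the properties create_data.py enforces:

```python
import numpy as np

t0 = np.loadtxt("cortical.lh.ShapeConnectivityTensor_OAS2_0001_MR1_t0.txt")
assert t0.shape == (35, 35)
assert np.allclose(t0, t0.T)          # symmetric
assert np.allclose(np.diag(t0), 0.0)  # cleaned diagonal
```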
31 |
32 | ## Running EvoGraphNet
33 |
34 | You can use EvoGraphNet.py, located under the "code" directory, to run the model. Parameters are set through command-line arguments.
35 |
36 | You can run the program with the following command:
37 |
38 | ```bash
39 | python EvoGraphNet.py --loss LS --epoch 500 --folds 5
40 | ```
41 |
42 | In this example, we use least squares as the adversarial loss and train for 500 epochs in each of the 5 folds. To run the code with the hyperparameters described in the paper, run it without any command-line arguments:
43 |
44 | ```bash
45 | python EvoGraphNet.py
46 | ```
47 |
48 | Other command-line arguments (an example run follows the list):
49 | 
50 | - `--lr_g`: Generator learning rate
51 | - `--lr_d`: Discriminator learning rate
52 | - `--loss`: Adversarial loss to use for training (choices: `BCE`, `LS`)
53 | - `--batch`: Batch size
54 | - `--epoch`: How many epochs to train
55 | - `--folds`: How many folds for cross-validation
56 | - `--tr_st`: Training strategy of the GANs:
57 |   - `same`: Train the generator and discriminator at the same time
58 |   - `turns`: Alternate between training the generator and the discriminator across epochs
59 |   - `idle`: Like `turns`, but each network waits more than one turn (the user chooses how many)
60 | - `--id_e`: If the training strategy is `idle`, for how many epochs each network idles
61 | - `--exp`: Experiment number, for logging purposes
62 | - `--tp_c`: Coefficient of the topology loss
63 | - `--g_c`: Coefficient of the adversarial loss
64 | - `--i_c`: Coefficient of the identity loss
65 | - `--kl_c`: Coefficient of the KL loss
66 | - `--decay`: Weight decay
67 | 
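For instance, the following (illustrative) run trains with the idle strategy, letting each network idle for 3 epochs, over 5 folds:

```bash
python EvoGraphNet.py --tr_st idle --id_e 3 --epoch 300 --folds 5 --exp 1
```
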
68 | You can run the following command to see the default value and a short description of each parameter.
69 |
70 | ```bash
71 | python EvoGraphNet.py --help
72 | ```
73 | ## Example Results
74 |
75 | Given brain connectivity data at t0, EvoGraphNet.py produces two matrices predicting the brain connections at t1 and t2. In this example, the matrices are 35 x 35.
76 |
77 | 
78 |
79 | ## YouTube videos to install and run the code and understand how EvoGraphNet works
80 |
81 | To install and run EvoGraphNet, check the following YouTube video:
82 | https://youtu.be/eTUeQ15FeRc
83 |
84 | To learn about how EvoGraphNet works, check the following YouTube video:
85 | https://youtu.be/aT---t2OBO0
86 |
87 | ## Please cite the following paper when using EvoGraphNet
88 |
89 | ```latex
90 | @inproceedings{neblikaplanrekik2020,
91 | title={Deep EvoGraphNet Architecture For Time-Dependent Brain Graph Data Synthesis From a Single Timepoint},
92 | author={Nebli, Ahmed and Kaplan, Ugur Ali and Rekik, Islem},
93 | booktitle={International Workshop on PRedictive Intelligence In MEdicine},
94 | year={2020},
95 | organization={Springer}
96 | }
97 | ```
98 |
99 | ## EvoGraphNet on arXiv
100 |
101 | Link: https://arxiv.org/abs/2009.13217
102 |
103 | ## License
104 | Our code is released under the MIT License (see the LICENSE file for details).
105 |
--------------------------------------------------------------------------------
/code/data_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import math
6 | import itertools
7 | import copy
8 | import pickle
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU
14 | from torch.autograd import Variable
15 | from torch.distributions import normal
16 |
17 | from sklearn import preprocessing
18 | from sklearn.preprocessing import MinMaxScaler
19 | from sklearn.model_selection import KFold
20 |
21 | from torch_geometric.data import Data, InMemoryDataset, DataLoader
22 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool
23 | from torch_geometric.utils import get_laplacian, to_dense_adj
24 |
25 | import matplotlib.pyplot as plt
26 |
27 |
28 | class MRDataset(InMemoryDataset):
29 |
30 | def __init__(self, root, src, dest, h, connectomes=1, subs=1000, transform=None, pre_transform=None):
31 |
32 | """
33 | src: Input to the model
34 | dest: Target output of the model
35 | h: Load LH or RH data
36 | subs: Maximum number of subjects
37 |
38 | Note: Since we do not reprocess the data if it is already processed, processed files should be
39 | deleted if there is any change in the data we are reading.
40 | """
41 |
42 | self.src, self.dest, self.h, self.subs, self.connectomes = src, dest, h, subs, connectomes
43 | super(MRDataset, self).__init__(root, transform, pre_transform)
44 | self.data, self.slices = torch.load(self.processed_paths[0])
45 |
46 | def data_read(self, h="lh", nbr_of_subs=1000, connectomes=1):
47 |
48 | """
49 | Takes the (maximum) number of subjects and hemisphere we are working on
50 | as arguments, returns t0, t1, t2's of the connectomes for each subject
51 | in a single torch.FloatTensor.
52 | """
53 |
54 | subs = None # Subjects
55 |
56 | data_path = "../data"
57 |
58 | for i in range(1, nbr_of_subs):
59 | s = data_path + "/cortical." + h.lower() + ".ShapeConnectivityTensor_OAS2_"
60 |             # Zero-pad the subject id to four digits (e.g. OAS2_0001), matching create_data.py
61 |             s += str(i).zfill(4)
62 |             s += "_"
63 |
64 | for mr in ["MR1", "MR2"]:
65 | try: # Sometimes subject we are looking for does not exist
66 | t0 = np.loadtxt(s + mr + "_t0.txt")
67 | t1 = np.loadtxt(s + mr + "_t1.txt")
68 | t2 = np.loadtxt(s + mr + "_t2.txt")
69 |                 except OSError:  # this subject's files do not exist; skip it
70 |                     continue
71 |
72 | # Read the connectomes at t0, t1 and t2, then stack them
73 | read_limit = (connectomes * 35)
74 | t_stacked = np.vstack((t0[:read_limit, :], t1[:read_limit, :], t2[:read_limit, :]))
75 | tsr = t_stacked.reshape(3, connectomes * 35, 35)
76 |
77 | if subs is None: # If first subject
78 | subs = tsr
79 | else:
80 | subs = np.vstack((subs, tsr))
81 |
82 | # Then, reshape to match the shape of the model's expected input shape
83 | # final_views should be a torch tensor or Pytorch Geometric complains
84 | final_views = torch.tensor(np.moveaxis(subs.reshape(-1, 3, (connectomes * 35), 35), 1, -1), dtype=torch.float)
85 |
86 | return final_views
87 |
88 | @property
89 | def processed_file_names(self):
90 | return [
91 | "data_" + str(self.connectomes) + "_" + self.h.lower() + "_" + str(self.subs) + "_" + str(self.src) + str(
92 | self.dest) + ".pt"]
93 |
94 | def process(self):
95 |
96 | """
97 | Prepares the data for PyTorch Geometric.
98 | """
99 |
100 |         unprocessed = self.data_read(self.h, self.subs, self.connectomes)
101 | num_samples, timestamps = unprocessed.shape[0], unprocessed.shape[-1]
102 |         assert 0 <= self.dest < timestamps
103 |         assert 0 <= self.src < timestamps
104 |
105 | # Turn the data into PyTorch Geometric Graphs
106 | data_list = list()
107 |
108 | for sample in range(num_samples):
109 | x = unprocessed[sample, :, :, self.src]
110 | y = unprocessed[sample, :, :, self.dest]
111 |
112 | edge_index, edge_attr, rows, cols = create_edge_index_attribute(x)
113 | y_edge_index, y_edge_attr, _, _ = create_edge_index_attribute(y)
114 |
115 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
116 | y=y, y_edge_index=y_edge_index, y_edge_attr=y_edge_attr)
117 |
118 | data.num_nodes = rows
119 | data_list.append(data)
120 |
121 | if self.pre_filter is not None:
122 | data_list = [data for data in data_list if self.pre_filter(data)]
123 |
124 | if self.pre_transform is not None:
125 | data_list = [self.pre_transform(data) for data in data_list]
126 |
127 | data, slices = self.collate(data_list)
128 | torch.save((data, slices), self.processed_paths[0])
129 |
130 |
131 | class MRDataset2(InMemoryDataset):
132 |
133 | def __init__(self, root, h, connectomes=1, subs=1000, transform=None, pre_transform=None):
134 |
135 | """
136 | src: Input to the model
137 | dest: Target output of the model
138 | h: Load LH or RH data
139 | subs: Maximum number of subjects
140 |
141 | Note: Since we do not reprocess the data if it is already processed, processed files should be
142 | deleted if there is any change in the data we are reading.
143 | """
144 |
145 | self.h, self.subs, self.connectomes = h, subs, connectomes
146 | super(MRDataset2, self).__init__(root, transform, pre_transform)
147 | self.data, self.slices = torch.load(self.processed_paths[0])
148 |
149 | def data_read(self, h="lh", nbr_of_subs=1000, connectomes=1):
150 |
151 | """
152 | Takes the (maximum) number of subjects and hemisphere we are working on
153 | as arguments, returns t0, t1, t2's of the connectomes for each subject
154 | in a single torch.FloatTensor.
155 | """
156 |
157 | subs = None # Subjects
158 |
159 | data_path = "../data"
160 |
161 | for i in range(1, nbr_of_subs):
162 | s = data_path + "/cortical." + h.lower() + ".ShapeConnectivityTensor_OAS2_"
163 |             # Zero-pad the subject id to four digits (e.g. OAS2_0001), matching create_data.py
164 |             s += str(i).zfill(4)
165 |             s += "_"
166 |
167 | for mr in ["MR1", "MR2"]:
168 | try: # Sometimes subject we are looking for does not exist
169 | t0 = np.loadtxt(s + mr + "_t0.txt")
170 | t1 = np.loadtxt(s + mr + "_t1.txt")
171 | t2 = np.loadtxt(s + mr + "_t2.txt")
172 |                 except OSError:  # this subject's files do not exist; skip it
173 |                     continue
174 |
175 | # Read the connectomes at t0, t1 and t2, then stack them
176 | read_limit = (connectomes * 35)
177 | t_stacked = np.vstack((t0[:read_limit, :], t1[:read_limit, :], t2[:read_limit, :]))
178 | tsr = t_stacked.reshape(3, connectomes * 35, 35)
179 |
180 | if subs is None: # If first subject
181 | subs = tsr
182 | else:
183 | subs = np.vstack((subs, tsr))
184 |
185 | # Then, reshape to match the shape of the model's expected input shape
186 | # final_views should be a torch tensor or Pytorch Geometric complains
187 | final_views = torch.tensor(np.moveaxis(subs.reshape(-1, 3, (connectomes * 35), 35), 1, -1), dtype=torch.float)
188 |
189 | return final_views
190 |
191 | @property
192 | def processed_file_names(self):
193 | return [
194 | "2data_" + str(self.connectomes) + "_" + self.h.lower() + "_" + str(self.subs) + "_" + ".pt"]
195 |
196 | def process(self):
197 |
198 | """
199 | Prepares the data for PyTorch Geometric.
200 | """
201 |
202 |         unprocessed = self.data_read(self.h, self.subs, self.connectomes)
203 | num_samples, timestamps = unprocessed.shape[0], unprocessed.shape[-1]
204 |
205 | # Turn the data into PyTorch Geometric Graphs
206 | data_list = list()
207 |
208 | for sample in range(num_samples):
209 | x = unprocessed[sample, :, :, 0]
210 | y = unprocessed[sample, :, :, 1]
211 | y2 = unprocessed[sample, :, :, 2]
212 |
213 | edge_index, edge_attr, rows, cols = create_edge_index_attribute(x)
214 | y_edge_index, y_edge_attr, _, _ = create_edge_index_attribute(y)
215 | y2_edge_index, y2_edge_attr, _, _ = create_edge_index_attribute(y2)
216 | y_distr = normal.Normal(y.mean(dim=1), y.std(dim=1))
217 | y2_distr = normal.Normal(y2.mean(dim=1), y2.std(dim=1))
218 | y_lap_ei, y_lap_ea = get_laplacian(y_edge_index, y_edge_attr)
219 | y2_lap_ei, y2_lap_ea = get_laplacian(y2_edge_index, y2_edge_attr)
220 | y_lap = to_dense_adj(y_lap_ei, edge_attr=y_lap_ea)
221 | y2_lap = to_dense_adj(y2_lap_ei, edge_attr=y2_lap_ea)
222 |
223 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
224 | y=y, y_edge_index=y_edge_index, y_edge_attr=y_edge_attr, y_distr=y_distr,
225 | y2=y2, y2_edge_index=y2_edge_index, y2_edge_attr=y2_edge_attr, y2_distr=y2_distr,
226 | y_lap=y_lap, y2_lap=y2_lap)
227 |
228 | data.num_nodes = rows
229 | data_list.append(data)
230 |
231 | if self.pre_filter is not None:
232 | data_list = [data for data in data_list if self.pre_filter(data)]
233 |
234 | if self.pre_transform is not None:
235 | data_list = [self.pre_transform(data) for data in data_list]
236 |
237 | data, slices = self.collate(data_list)
238 | torch.save((data, slices), self.processed_paths[0])
239 |
240 |
241 | def create_edge_index_attribute(adj_matrix):
242 | """
243 | Given an adjacency matrix, this function creates the edge index and edge attribute matrix
244 | suitable to graph representation in PyTorch Geometric.
245 | """
246 |
247 | rows, cols = adj_matrix.shape[0], adj_matrix.shape[1]
248 | edge_index = torch.zeros((2, rows * cols), dtype=torch.long)
249 | edge_attr = torch.zeros((rows * cols, 1), dtype=torch.float)
250 | counter = 0
251 |
252 | for src, attrs in enumerate(adj_matrix):
253 | for dest, attr in enumerate(attrs):
254 | edge_index[0][counter], edge_index[1][counter] = src, dest
255 | edge_attr[counter] = attr
256 | counter += 1
257 |
258 | return edge_index, edge_attr, rows, cols
259 |
260 |
261 | def swap(data):
262 | # Swaps the x & y values of the given graph
263 | edge_i, edge_attr, _, _ = create_edge_index_attribute(data.y)
264 | data_s = Data(x=data.y, edge_index=edge_i, edge_attr=edge_attr, y=data.x)
265 | return data_s
266 |
267 |
268 | def cross_val_indices(folds, num_samples, new=False):
269 | """
270 | Takes the number of inputs and number of folds.
271 | Determines indices to go into validation split in each turn.
272 | Saves the indices on a file for experimental reproducibility and does not overwrite
273 | the already determined indices unless new=True.
274 | """
275 |
276 | kf = KFold(n_splits=folds, shuffle=True)
277 | train_indices = list()
278 | val_indices = list()
279 |
280 | try:
281 |         if new:
282 | raise IOError
283 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_train", "rb") as f:
284 | train_indices = pickle.load(f)
285 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_val", "rb") as f:
286 | val_indices = pickle.load(f)
287 | except IOError:
288 | for tr_index, val_index in kf.split(np.zeros((num_samples, 1))):
289 | train_indices.append(tr_index)
290 | val_indices.append(val_index)
291 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_train", "wb") as f:
292 | pickle.dump(train_indices, f)
293 | with open("../data/" + str(folds) + "_" + str(num_samples) + "cv_val", "wb") as f:
294 | pickle.dump(val_indices, f)
295 |
296 | return train_indices, val_indices
297 |
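298 | 
299 | if __name__ == "__main__":
300 |     # Hypothetical smoke test (a sketch, not part of the original pipeline):
301 |     # convert a random 35 x 35 adjacency matrix into the edge_index / edge_attr
302 |     # form PyTorch Geometric expects and check the resulting shapes.
303 |     adj = torch.rand((35, 35))
304 |     edge_index, edge_attr, rows, cols = create_edge_index_attribute(adj)
305 |     assert edge_index.shape == (2, rows * cols)
306 |     assert edge_attr.shape == (rows * cols, 1)
307 |     print(edge_index.shape, edge_attr.shape)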
--------------------------------------------------------------------------------
/code/EvoGraphNet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import math
6 | import itertools
7 | import copy
8 | import pickle
9 | from sys import exit
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, LeakyReLU
15 | from torch.autograd import Variable
16 | from torch.distributions import normal, kl
17 |
18 | from sklearn import preprocessing
19 | from sklearn.preprocessing import MinMaxScaler
20 | from sklearn.model_selection import KFold
21 |
22 | from torch_geometric.data import Data, InMemoryDataset, DataLoader
23 | from torch_geometric.nn import NNConv, BatchNorm, EdgePooling, TopKPooling, global_add_pool
24 | from torch_geometric.utils import get_laplacian, to_dense_adj
25 |
26 | import matplotlib.pyplot as plt
27 |
28 | from data_utils import MRDataset, create_edge_index_attribute, swap, cross_val_indices, MRDataset2
29 | from model import Generator, Discriminator
30 | from plot import plot, plot_matrix
31 |
32 | torch.manual_seed(0) # To get the same results across experiments
33 |
34 | if torch.cuda.is_available():
35 | device = torch.device('cuda')
36 | print('running on GPU')
37 | else:
38 | device = torch.device("cpu")
39 | print('running on CPU')
40 |
41 | # Parser
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--lr_g', type=float, default=0.01, help='Generator learning rate')
44 | parser.add_argument('--lr_d', type=float, default=0.0002, help='Discriminator learning rate')
45 | parser.add_argument('--loss', type=str, default='BCE', help='Which loss to use for training',
46 | choices=['BCE', 'LS'])
47 | parser.add_argument('--batch', type=int, default=1, help='Batch Size')
48 | parser.add_argument('--epoch', type=int, default=500, help='How many epochs to train')
49 | parser.add_argument('--folds', type=int, default=3, help='How many folds for CV')
50 | parser.add_argument('--tr_st', type=str, default='same', help='Training strategy',
51 | choices=['same', 'turns', 'idle'])
52 | parser.add_argument('--id_e', type=int, default=2, help='If training strategy is idle, for how many epochs')
53 | parser.add_argument('--exp', type=int, default=0, help='Which experiment are you running')
54 | parser.add_argument('--tp_c', type=float, default=0.0, help='Coefficient of topology loss')
55 | parser.add_argument('--g_c', type=float, default=2.0, help='Coefficient of adversarial loss')
56 | parser.add_argument('--i_c', type=float, default=2.0, help='Coefficient of identity loss')
57 | parser.add_argument('--kl_c', type=float, default=0.001, help='Coefficient of KL loss')
58 | parser.add_argument('--decay', type=float, default=0.0, help='Weight Decay')
59 | opt = parser.parse_args()
60 |
61 | # Datasets
62 |
63 | h_data = MRDataset2("../data", "lh", subs=989)
64 |
65 | # Parameters
66 |
67 | batch_size = opt.batch
68 | lr_G = opt.lr_g
69 | lr_D = opt.lr_d
70 | num_epochs = opt.epoch
71 | folds = opt.folds
72 |
73 | connectomes = 1
74 | train_generator = 1
75 |
76 | # Coefficients for loss
77 | i_coeff = opt.i_c
78 | g_coeff = opt.g_c
79 | kl_coeff = opt.kl_c
80 | tp_coeff = opt.tp_c
81 |
82 | if opt.tr_st != 'idle':
83 | opt.id_e = 0
84 |
85 | # Training
86 |
87 | loss_dict = {"BCE": torch.nn.BCELoss().to(device),
88 | "LS": torch.nn.MSELoss().to(device)}
89 |
90 |
91 | adversarial_loss = loss_dict[opt.loss.upper()]
92 | identity_loss = torch.nn.L1Loss().to(device) # Will be used in training
93 | msel = torch.nn.MSELoss().to(device)
94 | mael = torch.nn.L1Loss().to(device)  # Not used in training; measures generator success
95 | counter_g, counter_d = 0, 0
96 | tp = torch.nn.MSELoss().to(device) # Used for node strength
97 |
98 | train_ind, val_ind = cross_val_indices(folds, len(h_data))
99 |
100 | # Saving the losses for the future
101 | gen_mae_losses_tr = None
102 | disc_real_losses_tr = None
103 | disc_fake_losses_tr = None
104 | gen_mae_losses_val = None
105 | disc_real_losses_val = None
106 | disc_fake_losses_val = None
107 | gen_mae_losses_tr2 = None
108 | disc_real_losses_tr2 = None
109 | disc_fake_losses_tr2 = None
110 | gen_mae_losses_val2 = None
111 | disc_real_losses_val2 = None
112 | disc_fake_losses_val2 = None
113 | k1_train_s = None
114 | k2_train_s = None
115 | k1_val_s = None
116 | k2_val_s = None
117 | tp1_train_s = None
118 | tp2_train_s = None
119 | tp1_val_s = None
120 | tp2_val_s = None
121 | gan1_train_s = None
122 | gan2_train_s = None
123 | gan1_val_s = None
124 | gan2_val_s = None
125 |
126 | # Cross Validation
127 | for fold in range(folds):
128 | train_set, val_set = h_data[list(train_ind[fold])], h_data[list(val_ind[fold])]
129 | h_data_train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
130 | h_data_test_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
131 | val_step = len(h_data_test_loader)
132 |
133 | for data in h_data_train_loader: # Determine the maximum number of samples in a batch
134 | data_size = data.x.size(0)
135 | break
136 |
137 | # Create generators and discriminators
138 | generator = Generator().to(device)
139 | generator2 = Generator().to(device)
140 | discriminator = Discriminator().to(device)
141 | discriminator2 = Discriminator().to(device)
142 |
143 | optimizer_G = torch.optim.AdamW(generator.parameters(), lr=lr_G, betas=(0.5, 0.999), weight_decay=opt.decay)
144 | optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=lr_D, betas=(0.5, 0.999), weight_decay=opt.decay)
145 | optimizer_G2 = torch.optim.AdamW(generator2.parameters(), lr=lr_G, betas=(0.5, 0.999), weight_decay=opt.decay)
146 | optimizer_D2 = torch.optim.AdamW(discriminator2.parameters(), lr=lr_D, betas=(0.5, 0.999), weight_decay=opt.decay)
147 |
148 | total_step = len(h_data_train_loader)
149 | real_label = torch.ones((data_size, 1)).to(device)
150 | fake_label = torch.zeros((data_size, 1)).to(device)
151 |
152 |
153 | # Will be used for reporting
154 | real_losses, fake_losses, mse_losses, mae_losses = list(), list(), list(), list()
155 | real_losses_val, fake_losses_val, mse_losses_val, mae_losses_val = list(), list(), list(), list()
156 |
157 | real_losses2, fake_losses2, mse_losses2, mae_losses2 = list(), list(), list(), list()
158 | real_losses_val2, fake_losses_val2, mse_losses_val2, mae_losses_val2 = list(), list(), list(), list()
159 |
160 | k1_losses, k2_losses, k1_losses_val, k2_losses_val = list(), list(), list(), list()
161 | tp_losses_1_tr, tp_losses_1_val, tp_losses_2_tr, tp_losses_2_val = list(), list(), list(), list()
162 | gan_losses_1_tr, gan_losses_1_val, gan_losses_2_tr, gan_losses_2_val = list(), list(), list(), list()
163 |
164 |
165 | for epoch in range(num_epochs):
166 | # Reporting
167 | r, f, d, g, mse_l, mae_l = 0, 0, 0, 0, 0, 0
168 | r_val, f_val, d_val, g_val, mse_l_val, mae_l_val = 0, 0, 0, 0, 0, 0
169 | k1_train, k2_train, k1_val, k2_val = 0.0, 0.0, 0.0, 0.0
170 | r2, f2, d2, g2, mse_l2, mae_l2 = 0, 0, 0, 0, 0, 0
171 | r_val2, f_val2, d_val2, g_val2, mse_l_val2, mae_l_val2 = 0, 0, 0, 0, 0, 0
172 | tp1_tr, tp1_val, tp2_tr, tp2_val = 0.0, 0.0, 0.0, 0.0
173 | gan1_tr, gan1_val, gan2_tr, gan2_val = 0.0, 0.0, 0.0, 0.0
174 |
175 | # Train
176 | generator.train()
177 | discriminator.train()
178 | generator2.train()
179 | discriminator2.train()
180 | for i, data in enumerate(h_data_train_loader):
181 | data = data.to(device)
182 |
183 | optimizer_D.zero_grad()
184 |
185 | # Train the discriminator
186 | # Create fake data
187 | fake_y = generator(data).detach()
188 | edge_i, edge_a, _, _ = create_edge_index_attribute(fake_y)
189 | fake_data = Data(x=fake_y, edge_attr=edge_a, edge_index=edge_i).to(device)
190 | swapped_data = Data(x=data.y, edge_attr=data.y_edge_attr, edge_index=data.y_edge_index).to(device)
191 |
192 | # data: Real source and target
193 | # fake_data: Real source and generated target
194 | real_loss = adversarial_loss(discriminator(swapped_data, data), real_label[:data.x.size(0), :])
195 | fake_loss = adversarial_loss(discriminator(fake_data, data), fake_label[:data.x.size(0), :])
196 | loss_D = torch.mean(real_loss + fake_loss) / 2
197 | r += real_loss.item()
198 | f += fake_loss.item()
199 | d += loss_D.item()
200 |
201 | # Depending on the chosen training method, we might update the parameters of the discriminator
202 | if (epoch % 2 == 1 and opt.tr_st == "turns") or opt.tr_st == "same" or counter_d >= opt.id_e:
203 | loss_D.backward(retain_graph=True)
204 | optimizer_D.step()
205 |
206 | # Train the generator
207 | optimizer_G.zero_grad()
208 |
209 | # Adversarial Loss
210 | fake_data.x = generator(data)
211 | gan_loss = torch.mean(adversarial_loss(discriminator(fake_data, data), real_label[:data.x.size(0), :]))
212 | gan1_tr += gan_loss.item()
213 |
214 | # KL Loss
215 | kl_loss = kl.kl_divergence(normal.Normal(fake_data.x.mean(dim=1), fake_data.x.std(dim=1)),
216 | normal.Normal(data.y.mean(dim=1), data.y.std(dim=1))).sum()
217 |
218 | # Topology Loss
219 | tp_loss = tp(fake_data.x.sum(dim=-1), data.y.sum(dim=-1))
220 | tp1_tr += tp_loss.item()
221 |
222 | # Identity Loss is included in the end
223 | loss_G = i_coeff * identity_loss(generator(swapped_data), data.y) + g_coeff * gan_loss + kl_coeff * kl_loss + tp_coeff * tp_loss
224 | g += loss_G.item()
225 | if (epoch % 2 == 0 and opt.tr_st == "turns") or opt.tr_st == "same" or counter_g < opt.id_e:
226 | loss_G.backward(retain_graph=True)
227 | optimizer_G.step()
228 | k1_train += kl_loss.item()
229 | mse_l += msel(generator(data), data.y).item()
230 | mae_l += mael(generator(data), data.y).item()
231 |
232 |             # ----- Training of the second GAN: predict t2 from the generated t1 -----
233 |
234 | optimizer_D2.zero_grad()
235 |
236 | # Train the discriminator2
237 |
238 | # Create fake data for t2 from fake data for t1
239 | fake_data.x = fake_data.x.detach()
240 | fake_y2 = generator2(fake_data).detach()
241 | edge_i, edge_a, _, _ = create_edge_index_attribute(fake_y2)
242 | fake_data2 = Data(x=fake_y2, edge_attr=edge_a, edge_index=edge_i).to(device)
243 | swapped_data2 = Data(x=data.y2, edge_attr=data.y2_edge_attr, edge_index=data.y2_edge_index).to(device)
244 |
245 | # fake_data: Data generated for t1
246 | # fake_data2: Data generated for t2 using generated data for t1
247 | # swapped_data2: Real t2 data
248 | real_loss = adversarial_loss(discriminator2(swapped_data2, fake_data), real_label[:data.x.size(0), :])
249 | fake_loss = adversarial_loss(discriminator2(fake_data2, fake_data), fake_label[:data.x.size(0), :])
250 | loss_D = torch.mean(real_loss + fake_loss) / 2
251 | r2 += real_loss.item()
252 | f2 += fake_loss.item()
253 | d2 += loss_D.item()
254 |
255 | if (epoch % 2 == 1 and opt.tr_st == "turns") or opt.tr_st == "same" or counter_d >= opt.id_e:
256 | loss_D.backward(retain_graph=True)
257 | optimizer_D2.step()
258 |
259 | # Train generator2
260 | optimizer_G2.zero_grad()
261 |
262 | # Adversarial Loss
263 | fake_data2.x = generator2(fake_data)
264 | gan_loss = torch.mean(adversarial_loss(discriminator2(fake_data2, fake_data), real_label[:data.x.size(0), :]))
265 | gan2_tr += gan_loss.item()
266 |
267 | # Topology Loss
268 | tp_loss = tp(fake_data2.x.sum(dim=-1), data.y2.sum(dim=-1))
269 | tp2_tr += tp_loss.item()
270 |
271 | # KL Loss
272 | kl_loss = kl.kl_divergence(normal.Normal(fake_data2.x.mean(dim=1), fake_data2.x.std(dim=1)),
273 | normal.Normal(data.y2.mean(dim=1), data.y2.std(dim=1))).sum()
274 |
275 | # Identity Loss
276 |             loss_G = i_coeff * identity_loss(generator2(swapped_data2), data.y2) + g_coeff * gan_loss + kl_coeff * kl_loss + tp_coeff * tp_loss
277 | g2 += loss_G.item()
278 | if (epoch % 2 == 0 and opt.tr_st == "turns") or opt.tr_st == "same" or counter_g < opt.id_e:
279 | loss_G.backward(retain_graph=True)
280 | optimizer_G2.step()
281 |
282 | k2_train += kl_loss.item()
283 | mse_l2 += msel(generator2(fake_data), data.y2).item()
284 | mae_l2 += mael(generator2(fake_data), data.y2).item()
285 |
286 | # Validate
287 | generator.eval()
288 | discriminator.eval()
289 | generator2.eval()
290 | discriminator2.eval()
291 |
292 | for i, data in enumerate(h_data_test_loader):
293 | data = data.to(device)
294 | # Train the discriminator
295 | # Create fake data
296 | fake_y = generator(data).detach()
297 | edge_i, edge_a, _, _ = create_edge_index_attribute(fake_y)
298 | fake_data = Data(x=fake_y, edge_attr=edge_a, edge_index=edge_i).to(device)
299 | swapped_data = Data(x=data.y, edge_attr=data.y_edge_attr, edge_index=data.y_edge_index).to(device)
300 |
301 | # data: Real source and target
302 | # fake_data: Real source and generated target
303 | real_loss = adversarial_loss(discriminator(swapped_data, data), real_label[:data.x.size(0), :])
304 | fake_loss = adversarial_loss(discriminator(fake_data, data), fake_label[:data.x.size(0), :])
305 | loss_D = torch.mean(real_loss + fake_loss) / 2
306 | r_val += real_loss.item()
307 | f_val += fake_loss.item()
308 | d_val += loss_D.item()
309 |
310 | # Adversarial Loss
311 | fake_data.x = generator(data)
312 | gan_loss = torch.mean(adversarial_loss(discriminator(fake_data, data), real_label[:data.x.size(0), :]))
313 | gan1_val += gan_loss.item()
314 |
315 | # Topology Loss
316 | tp_loss = tp(fake_data.x.sum(dim=-1), data.y.sum(dim=-1))
317 | tp1_val += tp_loss.item()
318 |
319 | kl_loss = kl.kl_divergence(normal.Normal(fake_data.x.mean(dim=1), fake_data.x.std(dim=1)),
320 | normal.Normal(data.y.mean(dim=1), data.y.std(dim=1))).sum()
321 |
322 | # Identity Loss
323 |
324 |             loss_G = i_coeff * identity_loss(generator(swapped_data), data.y) + g_coeff * gan_loss + kl_coeff * kl_loss + tp_coeff * tp_loss
325 | g_val += loss_G.item()
326 | mse_l_val += msel(generator(data), data.y).item()
327 | mae_l_val += mael(generator(data), data.y).item()
328 | k1_val += kl_loss.item()
329 |
330 | # Second GAN
331 |
332 | # Create fake data for t2 from fake data for t1
333 | fake_data.x = fake_data.x.detach()
334 | fake_y2 = generator2(fake_data)
335 | edge_i, edge_a, _, _ = create_edge_index_attribute(fake_y2)
336 | fake_data2 = Data(x=fake_y2, edge_attr=edge_a, edge_index=edge_i).to(device)
337 | swapped_data2 = Data(x=data.y2, edge_attr=data.y2_edge_attr, edge_index=data.y2_edge_index).to(device)
338 |
339 | # fake_data: Data generated for t1
340 | # fake_data2: Data generated for t2 using generated data for t1
341 | # swapped_data2: Real t2 data
342 | real_loss = adversarial_loss(discriminator2(swapped_data2, fake_data), real_label[:data.x.size(0), :])
343 | fake_loss = adversarial_loss(discriminator2(fake_data2, fake_data), fake_label[:data.x.size(0), :])
344 | loss_D = torch.mean(real_loss + fake_loss) / 2
345 | r_val2 += real_loss.item()
346 | f_val2 += fake_loss.item()
347 | d_val2 += loss_D.item()
348 |
349 | # Adversarial Loss
350 | fake_data2.x = generator2(fake_data)
351 | gan_loss = torch.mean(adversarial_loss(discriminator2(fake_data2, fake_data), real_label[:data.x.size(0), :]))
352 | gan2_val += gan_loss.item()
353 |
354 | # Topology Loss
355 | tp_loss = tp(fake_data2.x.sum(dim=-1), data.y2.sum(dim=-1))
356 | tp2_val += tp_loss.item()
357 |
358 | # KL Loss
359 | kl_loss = kl.kl_divergence(normal.Normal(fake_data2.x.mean(dim=1), fake_data2.x.std(dim=1)),
360 | normal.Normal(data.y2.mean(dim=1), data.y2.std(dim=1))).sum()
361 | k2_val += kl_loss.item()
362 |
363 | # Identity Loss
364 |             loss_G = i_coeff * identity_loss(generator2(swapped_data2), data.y2) + g_coeff * gan_loss + kl_coeff * kl_loss + tp_coeff * tp_loss
365 | g_val2 += loss_G.item()
366 | mse_l_val2 += msel(generator2(fake_data), data.y2).item()
367 | mae_l_val2 += mael(generator2(fake_data), data.y2).item()
368 |
369 | if opt.tr_st == 'idle':
370 | counter_g += 1
371 | counter_d += 1
372 | if counter_g == 2 * opt.id_e:
373 | counter_g = 0
374 | counter_d = 0
375 |
376 |
377 | print(f'Epoch [{epoch + 1}/{num_epochs}]')
378 | print(f'[Train]: D Loss: {d / total_step:.5f}, G Loss: {g / total_step:.5f} R Loss: {r / total_step:.5f}, F Loss: {f / total_step:.5f}, MSE: {mse_l / total_step:.5f}, MAE: {mae_l / total_step:.5f}')
379 | print(f'[Val]: D Loss: {d_val / val_step:.5f}, G Loss: {g_val / val_step:.5f} R Loss: {r_val / val_step:.5f}, F Loss: {f_val / val_step:.5f}, MSE: {mse_l_val / val_step:.5f}, MAE: {mae_l_val / val_step:.5f}')
380 | print(f'[Train]: D2 Loss: {d2 / total_step:.5f}, G2 Loss: {g2 / total_step:.5f} R2 Loss: {r2 / total_step:.5f}, F2 Loss: {f2 / total_step:.5f}, MSE: {mse_l2 / total_step:.5f}, MAE: {mae_l2 / total_step:.5f}')
381 | print(f'[Val]: D2 Loss: {d_val2 / val_step:.5f}, G2 Loss: {g_val2 / val_step:.5f} R2 Loss: {r_val2 / val_step:.5f}, F2 Loss: {f_val2 / val_step:.5f}, MSE: {mse_l_val2 / val_step:.5f}, MAE: {mae_l_val2 / val_step:.5f}')
382 |
383 | real_losses.append(r / total_step)
384 | fake_losses.append(f / total_step)
385 | mse_losses.append(mse_l / total_step)
386 | mae_losses.append(mae_l / total_step)
387 | real_losses_val.append(r_val / val_step)
388 | fake_losses_val.append(f_val / val_step)
389 | mse_losses_val.append(mse_l_val / val_step)
390 | mae_losses_val.append(mae_l_val / val_step)
391 | real_losses2.append(r2 / total_step)
392 | fake_losses2.append(f2 / total_step)
393 | mse_losses2.append(mse_l2 / total_step)
394 | mae_losses2.append(mae_l2 / total_step)
395 | real_losses_val2.append(r_val2 / val_step)
396 | fake_losses_val2.append(f_val2 / val_step)
397 | mse_losses_val2.append(mse_l_val2 / val_step)
398 | mae_losses_val2.append(mae_l_val2 / val_step)
399 | k1_losses.append(k1_train / total_step)
400 | k2_losses.append(k2_train / total_step)
401 | k1_losses_val.append(k1_val / val_step)
402 | k2_losses_val.append(k2_val / val_step)
403 | tp_losses_1_tr.append(tp1_tr / total_step)
404 | tp_losses_1_val.append(tp1_val / val_step)
405 | tp_losses_2_tr.append(tp2_tr / total_step)
406 | tp_losses_2_val.append(tp2_val / val_step)
407 | gan_losses_1_tr.append(gan1_tr / total_step)
408 | gan_losses_1_val.append(gan1_val / val_step)
409 | gan_losses_2_tr.append(gan2_tr / total_step)
410 | gan_losses_2_val.append(gan2_val / val_step)
411 |
412 | # Plot losses
413 | plot("BCE", "DiscriminatorRealLossTrainSet" + str(fold) + "_exp" + str(opt.exp), real_losses)
414 | plot("BCE", "DiscriminatorRealLossValSet" + str(fold) + "_exp" + str(opt.exp), real_losses_val)
415 | plot("BCE", "DiscriminatorFakeLossTrainSet" + str(fold) + "_exp" + str(opt.exp), fake_losses)
416 | plot("BCE", "DiscriminatorFakeLossValSet" + str(fold) + "_exp" + str(opt.exp), fake_losses_val)
417 | plot("MSE", "GeneratorMSELossTrainSet" + str(fold) + "_exp" + str(opt.exp), mse_losses)
418 | plot("MSE", "GeneratorMSELossValSet" + str(fold) + "_exp" + str(opt.exp), mse_losses_val)
419 | plot("MAE", "GeneratorMAELossTrainSet" + str(fold) + "_exp" + str(opt.exp), mae_losses)
420 | plot("MAE", "GeneratorMAELossValSet" + str(fold) + "_exp" + str(opt.exp), mae_losses_val)
421 | plot("BCE", "Discriminator2RealLossTrainSet" + str(fold) + "_exp" + str(opt.exp), real_losses2)
422 | plot("BCE", "Discriminator2RealLossValSet" + str(fold) + "_exp" + str(opt.exp), real_losses_val2)
423 | plot("BCE", "Discriminator2FakeLossTrainSet" + str(fold) + "_exp" + str(opt.exp), fake_losses2)
424 | plot("BCE", "Discriminator2FakeLossValSet" + str(fold) + "_exp" + str(opt.exp), fake_losses_val2)
425 | plot("MSE", "Generator2MSELossTrainSet" + str(fold) + "_exp" + str(opt.exp), mse_losses2)
426 | plot("MSE", "Generator2MSELossValSet" + str(fold) + "_exp" + str(opt.exp), mse_losses_val2)
427 | plot("MAE", "Generator2MAELossTrainSet" + str(fold) + "_exp" + str(opt.exp), mae_losses2)
428 | plot("MAE", "Generator2MAELossValSet" + str(fold) + "_exp" + str(opt.exp), mae_losses_val2)
429 | plot("KL Loss", "KL_Loss_1_TrainSet" + str(fold) + "_exp" + str(opt.exp), k1_losses)
430 | plot("KL Loss", "KL_Loss_1_ValSet" + str(fold) + "_exp" + str(opt.exp), k1_losses_val)
431 | plot("KL Loss", "KL_Loss_2_TrainSet" + str(fold) + "_exp" + str(opt.exp), k2_losses)
432 | plot("KL Loss", "KL_Loss_2_ValSet" + str(fold) + "_exp" + str(opt.exp), k2_losses_val)
433 | plot("TP Loss", "TP_Loss_1_TrainSet" + str(fold) + "_exp" + str(opt.exp), tp_losses_1_tr)
434 | plot("TP Loss", "TP_Loss_1_ValSet" + str(fold) + "_exp" + str(opt.exp), tp_losses_1_val)
435 | plot("TP Loss", "TP_Loss_2_TrainSet" + str(fold) + "_exp" + str(opt.exp), tp_losses_2_tr)
436 | plot("TP Loss", "TP_Loss_2_ValSet" + str(fold) + "_exp" + str(opt.exp), tp_losses_2_val)
437 | plot("BCE", "GAN_Loss_1_TrainSet" + str(fold) + "_exp" + str(opt.exp), gan_losses_1_tr)
438 | plot("BCE", "GAN_Loss_1_ValSet" + str(fold) + "_exp" + str(opt.exp), gan_losses_1_val)
439 | plot("BCE", "GAN_Loss_2_TrainSet" + str(fold) + "_exp" + str(opt.exp), gan_losses_2_tr)
440 | plot("BCE", "GAN_Loss_2_ValSet" + str(fold) + "_exp" + str(opt.exp), gan_losses_2_val)
441 |
442 | # Save the losses
443 | if gen_mae_losses_tr is None:
444 | gen_mae_losses_tr = mae_losses
445 | disc_real_losses_tr = real_losses
446 | disc_fake_losses_tr = fake_losses
447 | gen_mae_losses_val = mae_losses_val
448 | disc_real_losses_val = real_losses_val
449 | disc_fake_losses_val = fake_losses_val
450 | gen_mae_losses_tr2 = mae_losses2
451 | disc_real_losses_tr2 = real_losses2
452 | disc_fake_losses_tr2 = fake_losses2
453 | gen_mae_losses_val2 = mae_losses_val2
454 | disc_real_losses_val2 = real_losses_val2
455 | disc_fake_losses_val2 = fake_losses_val2
456 | k1_train_s = k1_losses
457 | k2_train_s = k2_losses
458 | k1_val_s = k1_losses_val
459 | k2_val_s = k2_losses_val
460 | tp1_train_s = tp_losses_1_tr
461 | tp2_train_s = tp_losses_2_tr
462 | tp1_val_s = tp_losses_1_val
463 | tp2_val_s = tp_losses_2_val
464 | gan1_train_s = gan_losses_1_tr
465 | gan2_train_s = gan_losses_2_tr
466 | gan1_val_s = gan_losses_1_val
467 | gan2_val_s = gan_losses_2_val
468 | else:
469 | gen_mae_losses_tr = np.vstack([gen_mae_losses_tr, mae_losses])
470 | disc_real_losses_tr = np.vstack([disc_real_losses_tr, real_losses])
471 | disc_fake_losses_tr = np.vstack([disc_fake_losses_tr, fake_losses])
472 | gen_mae_losses_val = np.vstack([gen_mae_losses_val, mae_losses_val])
473 | disc_real_losses_val = np.vstack([disc_real_losses_val, real_losses_val])
474 | disc_fake_losses_val = np.vstack([disc_fake_losses_val, fake_losses_val])
475 | gen_mae_losses_tr2 = np.vstack([gen_mae_losses_tr2, mae_losses2])
476 | disc_real_losses_tr2 = np.vstack([disc_real_losses_tr2, real_losses2])
477 | disc_fake_losses_tr2 = np.vstack([disc_fake_losses_tr2, fake_losses2])
478 | gen_mae_losses_val2 = np.vstack([gen_mae_losses_val2, mae_losses_val2])
479 | disc_real_losses_val2 = np.vstack([disc_real_losses_val2, real_losses_val2])
480 | disc_fake_losses_val2 = np.vstack([disc_fake_losses_val2, fake_losses_val2])
481 | k1_train_s = np.vstack([k1_train_s, k1_losses])
482 | k2_train_s = np.vstack([k2_train_s, k2_losses])
483 | k1_val_s = np.vstack([k1_val_s, k1_losses_val])
484 | k2_val_s = np.vstack([k2_val_s, k2_losses_val])
485 | tp1_train_s = np.vstack([tp1_train_s, tp_losses_1_tr])
486 | tp2_train_s = np.vstack([tp2_train_s, tp_losses_2_tr])
487 | tp1_val_s = np.vstack([tp1_val_s, tp_losses_1_val])
488 | tp2_val_s = np.vstack([tp2_val_s, tp_losses_2_val])
489 | gan1_train_s = np.vstack([gan1_train_s, gan_losses_1_tr])
490 | gan2_train_s = np.vstack([gan2_train_s, gan_losses_2_tr])
491 | gan1_val_s = np.vstack([gan1_val_s, gan_losses_1_val])
492 | gan2_val_s = np.vstack([gan2_val_s, gan_losses_2_val])
493 |
494 | # Save the models
495 | torch.save(generator.state_dict(), "../weights/generator_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
496 | torch.save(discriminator.state_dict(), "../weights/discriminator_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
497 | torch.save(generator2.state_dict(),
498 | "../weights/generator2_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
499 | torch.save(discriminator2.state_dict(),
500 | "../weights/discriminator2_" + str(fold) + "_" + str(epoch) + "_" + str(opt.exp))
501 |
502 | del generator
503 | del discriminator
504 |
505 | del generator2
506 | del discriminator2
507 |
508 | # Save losses
509 | with open("../losses/G_TrainLoss_exp_" + str(opt.exp), "wb") as f:
510 | pickle.dump(gen_mae_losses_tr, f)
511 | with open("../losses/G_ValLoss_exp_" + str(opt.exp), "wb") as f:
512 | pickle.dump(gen_mae_losses_val, f)
513 | with open("../losses/D_TrainRealLoss_exp_" + str(opt.exp), "wb") as f:
514 | pickle.dump(disc_real_losses_tr, f)
515 | with open("../losses/D_TrainFakeLoss_exp_" + str(opt.exp), "wb") as f:
516 | pickle.dump(disc_fake_losses_tr, f)
517 | with open("../losses/D_ValRealLoss_exp_" + str(opt.exp), "wb") as f:
518 | pickle.dump(disc_real_losses_val, f)
519 | with open("../losses/D_ValFakeLoss_exp_" + str(opt.exp), "wb") as f:
520 | pickle.dump(disc_fake_losses_val, f)
521 | with open("../losses/G2_TrainLoss_exp_" + str(opt.exp), "wb") as f:
522 | pickle.dump(gen_mae_losses_tr2, f)
523 | with open("../losses/G2_ValLoss_exp_" + str(opt.exp), "wb") as f:
524 | pickle.dump(gen_mae_losses_val2, f)
525 | with open("../losses/D2_TrainRealLoss_exp_" + str(opt.exp), "wb") as f:
526 | pickle.dump(disc_real_losses_tr2, f)
527 | with open("../losses/D2_TrainFakeLoss_exp_" + str(opt.exp), "wb") as f:
528 | pickle.dump(disc_fake_losses_tr2, f)
529 | with open("../losses/D2_ValRealLoss_exp_" + str(opt.exp), "wb") as f:
530 | pickle.dump(disc_real_losses_val2, f)
531 | with open("../losses/D2_ValFakeLoss_exp_" + str(opt.exp), "wb") as f:
532 | pickle.dump(disc_fake_losses_val2, f)
533 | with open("../losses/GenTotal_Train_exp_" + str(opt.exp), "wb") as f:
534 | pickle.dump(gen_mae_losses_tr + gen_mae_losses_tr2, f)
535 | with open("../losses/GenTotal_Val_exp_" + str(opt.exp), "wb") as f:
536 | pickle.dump(gen_mae_losses_val + gen_mae_losses_val2, f)
537 | with open("../losses/K1_TrainLoss_exp_" + str(opt.exp), "wb") as f:
538 | pickle.dump(k1_train_s, f)
539 | with open("../losses/K1_ValLoss_exp_" + str(opt.exp), "wb") as f:
540 | pickle.dump(k2_train_s, f)
541 | with open("../losses/K2_TrainLoss_exp_" + str(opt.exp), "wb") as f:
542 | pickle.dump(k1_val_s, f)
543 | with open("../losses/K2_ValLoss_exp_" + str(opt.exp), "wb") as f:
544 | pickle.dump(k2_val_s, f)
545 | with open("../losses/TP1_TrainLoss_exp_" + str(opt.exp), "wb") as f:
546 | pickle.dump(tp1_train_s, f)
547 | with open("../losses/TP1_ValLoss_exp_" + str(opt.exp), "wb") as f:
548 | pickle.dump(tp2_train_s, f)
549 | with open("../losses/TP2_TrainLoss_exp_" + str(opt.exp), "wb") as f:
550 | pickle.dump(tp1_val_s, f)
551 | with open("../losses/TP2_ValLoss_exp_" + str(opt.exp), "wb") as f:
552 | pickle.dump(tp2_val_s, f)
553 | with open("../losses/GAN1_TrainLoss_exp_" + str(opt.exp), "wb") as f:
554 | pickle.dump(gan1_train_s, f)
555 | with open("../losses/GAN1_ValLoss_exp_" + str(opt.exp), "wb") as f:
556 | pickle.dump(gan2_train_s, f)
557 | with open("../losses/GAN2_TrainLoss_exp_" + str(opt.exp), "wb") as f:
558 | pickle.dump(gan1_val_s, f)
559 | with open("../losses/GAN2_ValLoss_exp_" + str(opt.exp), "wb") as f:
560 | pickle.dump(gan2_val_s, f)
561 |
562 | print(f"Training Complete for experiment {opt.exp}!")
563 |
564 |
--------------------------------------------------------------------------------