├── .gitignore
├── README.md
└── source
    ├── acc.svg
    ├── loss.svg
    ├── models
    │   ├── classifier.py
    │   ├── diffpool.py
    │   └── graphsage.py
    ├── notebooks
    │   └── dataset_stats.ipynb
    ├── train.py
    └── utils
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### directories
2 |
3 | data/
4 | **/.idea
5 | **/.ipynb_checkpoints
6 | **/__pycache__
7 |
8 |
9 | ### files
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Convolution + Pooling NN PyTorch Implementation
2 |
3 | ## A graph classification neural network in PyTorch, combining graph convolution and graph pooling modules
4 |
5 | Work based on:
6 |
7 | - [GCNs](https://arxiv.org/pdf/1609.02907.pdf)
8 | - [GraphSAGE](https://arxiv.org/pdf/1706.02216.pdf)
9 | - [DiffPool](https://arxiv.org/pdf/1806.08804.pdf)
10 | - [Graph U-net](https://openreview.net/pdf?id=HJePRoAct7)
11 |
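12 | ## Architecture
13 |
14 | `Classifier` (in `source/models/classifier.py`) alternates GraphSAGE convolutions with DiffPool layers. Each DiffPool layer learns node embeddings `Z = GNN_embed(A, X)` and a soft cluster assignment `S = softmax(GNN_pool(A, X))`, then coarsens the graph via `X' = Sᵀ Z` and `A' = Sᵀ A S`. The final DiffPool layer collapses the graph to a single node, whose embedding feeds two linear layers for binary classification.
15 |
16 | ## Usage
17 |
18 | Training expects the MUTAG pickle at `../data/mutag.graph` (the `--dataset` default in `train.py`). An example invocation, using the flags `train.py` defines:
19 |
20 | ```
21 | cd source
22 | python train.py --epochs 100 --device cuda:0 --dataset ../data/mutag.graph
23 | ```
24 |
25 | Train/validation accuracy and loss curves are saved every 10 epochs to `acc.svg` and `loss.svg`.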
--------------------------------------------------------------------------------
/source/acc.svg:
--------------------------------------------------------------------------------
[SVG content omitted (1059 lines): training/validation accuracy curve, written by generate_plot() in source/utils/utils.py]
--------------------------------------------------------------------------------
/source/loss.svg:
--------------------------------------------------------------------------------
[SVG content omitted (1049 lines): training/validation loss curve, written by generate_plot() in source/utils/utils.py]
--------------------------------------------------------------------------------
/source/models/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | from models.diffpool import DiffPool
7 | from models.graphsage import GraphSAGE
8 |
9 | class Classifier(nn.Module):
10 |
11 |     def __init__(self, device="cuda:0"):
12 |         super(Classifier, self).__init__()
13 |         self.device = device
14 |         self.sage1 = GraphSAGE(7, 14, device=self.device)   # MUTAG node features are 7-dim one-hot vectors
15 |         self.pool1 = DiffPool(14, 14, device=self.device)   # coarsen the graph to 14 clusters
16 |         self.sage2 = GraphSAGE(14, 28, device=self.device)
17 |         self.pool2 = DiffPool(28, 1, final_layer=True, device=self.device)  # collapse to a single node
18 |         self.linear1 = nn.Linear(28, 10)
19 |         self.linear2 = nn.Linear(10, 2)  # two output classes: binary graph classification
20 |
21 |     def forward(self, x, a):
22 |         x = self.sage1(x, a)
23 |         x, a = self.pool1(x, a)
24 |         x = self.sage2(x, a)
25 |         x, a = self.pool2(x, a)
26 |         y = self.linear1(x.squeeze(0))  # (1, 28) -> (28,)
27 |         y = self.linear2(y)
28 |         return x, a, y
29 |
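30 | if __name__ == "__main__":
31 |     # Minimal smoke test on a hypothetical random graph (a sketch, not part of the
32 |     # training pipeline): 28 nodes is the MUTAG maximum, 7 the one-hot feature size.
33 |     # Run from source/ as `python -m models.classifier`.
34 |     model = Classifier(device="cpu")
35 |     x = torch.rand(28, 7)
36 |     a = (torch.rand(28, 28) > 0.5).float()
37 |     a = ((a + a.t()) > 0).float()  # symmetrize the adjacency matrix
38 |     x_out, a_out, y = model(x, a)
39 |     print(y.shape)  # expected: torch.Size([2])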
--------------------------------------------------------------------------------
/source/models/diffpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | from models.graphsage import GraphSAGE
6 |
7 | class DiffPool(nn.Module):
8 |
9 |     def __init__(self, feature_size, output_dim, device="cuda:0", final_layer=False):
10 |         super(DiffPool, self).__init__()
11 |         self.device = device
12 |         self.feature_size = feature_size
13 |         self.output_dim = output_dim
14 |         self.embed = GraphSAGE(self.feature_size, self.feature_size, device=self.device)  # GNN_embed
15 |         self.pool = GraphSAGE(self.feature_size, self.output_dim, device=self.device)     # GNN_pool
16 |         self.final_layer = final_layer
17 |
18 |     def forward(self, x, a):
19 |         z = self.embed(x, a)
20 |         if self.final_layer:
21 |             s = torch.ones(x.size(0), self.output_dim, device=self.device)  # fixed assignment: all nodes into one cluster
22 |         else:
23 |             s = F.softmax(self.pool(x, a), dim=1)  # soft cluster assignment S
24 |         x_new = s.t() @ z       # X' = S^T Z
25 |         a_new = s.t() @ a @ s   # A' = S^T A S
26 |         return x_new, a_new
27 |
28 |
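29 | if __name__ == "__main__":
30 |     # Shape check on a hypothetical random 6-node graph pooled into 2 clusters
31 |     # (a sketch; run from source/ as `python -m models.diffpool`).
32 |     pool = DiffPool(feature_size=4, output_dim=2, device="cpu")
33 |     x = torch.rand(6, 4)
34 |     a = (torch.rand(6, 6) > 0.5).float()
35 |     x_new, a_new = pool(x, a)
36 |     print(x_new.shape, a_new.shape)  # expected: torch.Size([2, 4]) torch.Size([2, 2])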
--------------------------------------------------------------------------------
/source/models/graphsage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | class GraphSAGE(nn.Module):
6 |
7 |     def __init__(self, input_feat, output_feat, device="cuda:0", normalize=True):
8 |         super(GraphSAGE, self).__init__()
9 |         self.device = device
10 |         self.normalize = normalize
11 |         self.input_feat = input_feat
12 |         self.output_feat = output_feat
13 |         self.linear = nn.Linear(self.input_feat, self.output_feat)
14 |         self.layer_norm = nn.LayerNorm(self.output_feat)  # elementwise_affine=False
15 |         nn.init.xavier_uniform_(self.linear.weight)
16 |
17 |     def aggregate_convolutional(self, x, a):
18 |         eye = torch.eye(a.shape[0], dtype=torch.float, device=self.device)
19 |         a = a + eye  # add self-loops
20 |         h_hat = a @ x  # (A + I) @ X: sum features over each node and its neighbors
21 |
22 |         return h_hat
23 |
24 |     def forward(self, x, a):
25 |         h_hat = self.aggregate_convolutional(x, a)
26 |         h = F.relu(self.linear(h_hat))
27 |         if self.normalize:
28 |             # h = F.normalize(h, p=2, dim=1)  # Normalize edge embeddings
29 |             h = self.layer_norm(h)  # Normalize layerwise (mean=0, std=1)
30 |
31 |         return h
32 |
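33 | if __name__ == "__main__":
34 |     # Sanity check on a hypothetical 3-node path graph (a sketch, CPU only):
35 |     # aggregate_convolutional computes (A + I) @ X.
36 |     sage = GraphSAGE(2, 4, device="cpu")
37 |     x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
38 |     a = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
39 |     print(sage.aggregate_convolutional(x, a))  # first row: x_0 + x_1 = [1., 1.]
40 |     print(sage(x, a).shape)  # expected: torch.Size([3, 4])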
--------------------------------------------------------------------------------
/source/notebooks/dataset_stats.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import argparse\n",
10 | "\n",
11 | "import torch\n",
12 | "import torch.nn as nn\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "from torch.nn import functional as F\n",
15 | "from torchvision.utils import make_grid\n",
16 | "from torchvision.utils import save_image\n",
17 | "\n",
18 | "import os\n",
19 | "\n",
20 | "import pickle\n",
21 | "from collections import Counter\n",
22 | "import numpy as np"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "# ds_name = '../../data/collab.graph' # Without node features\n",
32 | "ds_name = '../../data/mutag.graph' # With node features"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "def load_data(ds_name):\n",
42 | " f = open(ds_name, \"rb\")\n",
43 | " print(\"Found dataset:\", ds_name)\n",
44 | " data = pickle.load(f, encoding=\"latin1\")\n",
45 | " graph_data = data[\"graph\"]\n",
46 | " labels = data[\"labels\"]\n",
47 | " labels = np.array(labels, dtype = np.float)\n",
48 | " return graph_data, labels"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 4,
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "name": "stdout",
58 | "output_type": "stream",
59 | "text": [
60 | "Found dataset: ../../data/mutag.graph\n"
61 | ]
62 | }
63 | ],
64 | "source": [
65 | "graphs, labels = load_data(ds_name)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 5,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "Dataset: ../../data/mutag.graph \n",
78 | "Number of Graphs: 188 \n",
79 | "Label distribution: Counter({1.0: 125, -1.0: 63})\n",
80 | "\n",
81 | "Mean #nodes: 17.930851063829788 \n",
82 | "Median #nodes: 17.5 \n",
83 | "Max #nodes: 28 \n",
84 | "Min #nodes: 10 \n",
85 | "Total #nodes: 3371\n",
86 | "\n",
87 | "Mean #edges: 2.2076535152773658 \n",
88 | "Median #edges: 2.0 \n",
89 | "Max #edges: 4 \n",
90 | "Min #edges: 1 \n",
91 | "Total #edges: 7442\n",
92 | "\n",
93 | "Mean #features_len: 1.0 \n",
94 | "Median #features_len: 1.0 \n",
95 | "Max #features_len: 1 \n",
96 | "Min #features_len: 1 \n",
97 | "Total #features_len: 3371\n",
98 | "\n",
99 | "Number of nodes with features: 3371\n",
100 | "Features distribution: Counter({(3,): 2395, (7,): 593, (6,): 345, (2,): 23, (4,): 12, (1,): 2, (5,): 1})\n"
101 | ]
102 | }
103 | ],
104 | "source": [
105 | "print (\"Dataset: %s \\nNumber of Graphs: %s \\nLabel distribution: %s\"%(ds_name, len(graphs), Counter(labels)))\n",
106 | "avg_edges = []\n",
107 | "avg_nodes = []\n",
108 | "n_features = 0\n",
109 | "avg_features = []\n",
110 | "features = []\n",
111 | "for gidxs, nodes in graphs.items():\n",
112 | " for n in nodes:\n",
113 | " avg_edges.append(len(nodes[n]['neighbors']))\n",
114 | " if nodes[n]['label'] != '':\n",
115 | " n_features += 1\n",
116 | " avg_features.append(len(nodes[n]['label']))\n",
117 | " features.append(nodes[n]['label'])\n",
118 | " else:\n",
119 | " avg_features.append(0)\n",
120 | " features.append(None)\n",
121 | " avg_nodes.append(len(nodes))\n",
122 | "print(\"\\nMean #nodes: %s \\nMedian #nodes: %s \\nMax #nodes: %s \\nMin #nodes: %s \\nTotal #nodes: %s\"%(np.mean(avg_nodes), np.median(avg_nodes), max(avg_nodes), min(avg_nodes), sum(avg_nodes)))\n",
123 | "print(\"\\nMean #edges: %s \\nMedian #edges: %s \\nMax #edges: %s \\nMin #edges: %s \\nTotal #edges: %s\"%(np.mean(avg_edges), np.median(avg_edges), max(avg_edges), min(avg_edges), sum(avg_edges)))\n",
124 | "print(\"\\nMean #features_len: %s \\nMedian #features_len: %s \\nMax #features_len: %s \\nMin #features_len: %s \\nTotal #features_len: %s\"%(np.mean(avg_features), np.median(avg_features), max(avg_features), min(avg_features), sum(avg_features)))\n",
125 | "print(\"\\nNumber of nodes with features: %s\"%(n_features))\n",
126 | "print(\"Features distribution: %s\"%(Counter(features)))"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 6,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stdout",
136 | "output_type": "stream",
137 | "text": [
138 | "Example of Graph keys (e.g. node_ids):\n",
139 | " dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])\n"
140 | ]
141 | }
142 | ],
143 | "source": [
144 | "graph = graphs[5]\n",
145 | "nodes = graph.keys()\n",
146 | "print(\"Example of Graph keys (e.g. node_ids):\\n\", nodes)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 7,
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "Example of Node:\n",
159 | " {'neighbors': array([ 1, 13], dtype=uint8), 'label': (3,)}\n"
160 | ]
161 | }
162 | ],
163 | "source": [
164 | "node = graph[0]\n",
165 | "print(\"Example of Node:\\n\",node)"
166 | ]
167 | }
168 | ],
169 | "metadata": {
170 | "kernelspec": {
171 | "display_name": "Python 3",
172 | "language": "python",
173 | "name": "python3"
174 | },
175 | "language_info": {
176 | "codemirror_mode": {
177 | "name": "ipython",
178 | "version": 3
179 | },
180 | "file_extension": ".py",
181 | "mimetype": "text/x-python",
182 | "name": "python",
183 | "nbconvert_exporter": "python",
184 | "pygments_lexer": "ipython3",
185 | "version": "3.6.4"
186 | }
187 | },
188 | "nbformat": 4,
189 | "nbformat_minor": 2
190 | }
191 |
--------------------------------------------------------------------------------
/source/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import matplotlib.pyplot as plt
11 | from torch.nn import functional as F
12 | from torchvision.utils import make_grid
13 | from torchvision.utils import save_image
14 |
15 | from models.classifier import Classifier
16 | from models.graphsage import GraphSAGE
17 | from models.diffpool import DiffPool
18 | from utils.utils import *
19 |
20 | SEED = 42
21 | NORMALIZE = True
22 |
23 | def compute_accuracy(pred, target, device="cuda:0"):
24 |     pred_labels = torch.stack(pred, dim=0).to(device=device)
25 |     acc = (pred_labels.long() == target.long()).float().mean()
26 |
27 |     return acc
28 |
29 | def run_epoch(classifier, optimizer, criterion, x_data, a_data, t_data, evaluate=False, device="cuda:0"):
30 |     data_len = x_data.size(0)
31 |     pred = []
32 |     losses = []
33 |     scores = []
34 |     if evaluate:
35 |         classifier.eval()
36 |     else:
37 |         classifier.train()
38 |
39 |     for i in range(data_len):
40 |         optimizer.zero_grad()
41 |         x = x_data[i]
42 |         a = a_data[i]
43 |         t = t_data[i]
44 |
45 |         _, _, y = classifier(x, a)
46 |         loss = criterion(y.unsqueeze(0), t.long().unsqueeze(0))
47 |         if not evaluate:
48 |             loss.backward()
49 |             optimizer.step()
50 |         pred.append(y.argmax())
51 |         losses.append(loss.item())  # store the scalar, not the graph-holding tensor
52 |         scores.append(y)
53 |     acc_epoch = compute_accuracy(pred, t_data, device=device).item()
54 |     loss_epoch = float(np.mean(losses))
55 |
56 |     return acc_epoch, loss_epoch
57 |
58 |
59 | def main():
60 |     print(ARGS)
61 |     start_time = time.time()
62 |
63 |     device = torch.device(ARGS.device)
64 |     np.random.seed(SEED)
65 |     torch.manual_seed(SEED)
66 |     if ARGS.device != "cpu":
67 |         torch.cuda.manual_seed(SEED)
68 |     normalize = NORMALIZE
69 |
70 |     dataset_helper = DatasetHelper()
71 |     dataset_helper.load_dataset(ARGS.dataset, device=device, seed=SEED, normalize=normalize)
72 |     (x_train, a_train, labels_train) = dataset_helper.train
73 |     (x_valid, a_valid, labels_valid) = dataset_helper.valid
74 |     # feature_size = dataset_helper.feature_size
75 |     print("Imported dataset, generated train and validation splits, took: {:.3f}s".format(time.time() - start_time))
76 |
77 |     # x1 = x_train[0]
78 |     # a1 = a_train[0]
79 |     # t1 = labels_train[0]
80 |     # # t1_onehot = to_onehot_labels(t1)
81 |     # classifier = Classifier(device=device).to(device=device)
82 |     # _ = classifier(x1, a1)
83 |
84 |     # # Try GraphSAGE
85 |     # gcn = GraphSAGE(feature_size, feature_size*2, device=device, normalize=normalize).to(device=device)
86 |     # gcn.train()
87 |     # weights = gcn.linear.weight
88 |     # z1 = gcn(x1, a1)
89 |     # print()
90 |     #
91 |     # # Try DiffPool
92 |     # diffpool = DiffPool(feature_size, x1.size(0)//2, device=device).to(device=device)
93 |     # diffpool.train()
94 |     # x1_new, a1_new = diffpool(x1, a1)
95 |     # print()
96 |     # input("remove test")
97 |
98 |     # Try Classifier
99 |
100 |     classifier = Classifier(device=device).to(device=device)
101 |     optimizer = optim.Adam(classifier.parameters(), lr=ARGS.lr_rate)
102 |     criterion = nn.CrossEntropyLoss()
103 |
104 |     assert x_train.size(0) == dataset_helper.train_size == labels_train.size(0)
105 |     assert x_valid.size(0) == dataset_helper.valid_size == labels_valid.size(0)
106 |     assert (labels_train.sum() + labels_valid.sum()) < 1000000  # loose sanity check on label magnitudes
107 |
108 |     print("\nStarted training")
109 |     acc_valid, loss_valid = run_epoch(classifier, optimizer, criterion, x_valid, a_valid, labels_valid, evaluate=True,
110 |                                       device=device)
111 |     print("Epoch {:.0f} | Acc (Valid): {:.3f} | Loss (Valid): {:.3f} |".format(0, acc_valid, loss_valid))
112 |
113 |     measures = {
114 |         "acc": {
115 |             "train": [],
116 |             "valid": []
117 |         },
118 |         "loss": {
119 |             "train": [],
120 |             "valid": []
121 |         }
122 |     }
123 |
124 |     for e in range(ARGS.epochs):
125 |
126 |         acc_train, loss_train = run_epoch(classifier, optimizer, criterion, x_train, a_train, labels_train, evaluate=False, device=device)
127 |         acc_valid, loss_valid = run_epoch(classifier, optimizer, criterion, x_valid, a_valid, labels_valid, evaluate=True, device=device)
128 |         print("Epoch {:.0f} | Acc (Train/Valid): {:.3f}/{:.3f} | Loss (Train/Valid): {:.3f}/{:.3f} |"
129 |               .format(e + 1, acc_train, acc_valid, loss_train, loss_valid))
130 |
131 |         measures["acc"]["train"].append(acc_train)
132 |         measures["acc"]["valid"].append(acc_valid)
133 |         measures["loss"]["train"].append(loss_train)
134 |         measures["loss"]["valid"].append(loss_valid)
135 |
136 |         if e % 10 == 0:
137 |             for k in measures.keys():
138 |                 generate_plot(measures[k]["train"], measures[k]["valid"], title=k)
139 |
140 |     print()
141 |
142 |
143 |
144 |
145 | if __name__ == "__main__":
146 |     parser = argparse.ArgumentParser()
147 |     parser.add_argument('--epochs', default=100, type=int,
148 |                         help='max number of epochs')
149 |     parser.add_argument('--batch_size', default=16, type=int,
150 |                         help='size of each batch (currently unused: training runs one graph at a time)')
151 |     parser.add_argument('--lr_rate', default=1e-4, type=float,
152 |                         help='learning rate')
153 |     parser.add_argument('--device', default="cuda:0", type=str,
154 |                         help='training device')
155 |     parser.add_argument('--dataset', default="../data/mutag.graph", type=str,
156 |                         help='dataset path')
157 |
158 |     ARGS = parser.parse_args()
159 |
160 |     main()
161 |
162 |
--------------------------------------------------------------------------------
/source/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import errno
4 | import pickle
5 |
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 | import torch
10 | from torch.nn import functional as F
11 |
12 | KNOWN_DATASETS = {"../data/mutag.graph"}
13 |
14 |
15 | class DatasetHelper(object):
16 |
17 |     def __init__(self):
18 |         self.train = None
19 |         self.valid = None
20 |         self.feature_size = -1
21 |         self.n_graphs = -1
22 |         self.n_nodes = -1
23 |         self.train_size = -1
24 |         self.valid_size = -1
25 |         self.onehot = None
26 |         self.labels_set = None
27 |         self.TRAIN_RATIO = 0.8
28 |
29 |
30 |     def read_file(self, ds_name):
31 |
32 |         if not os.path.isfile(ds_name):
33 |             raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), ds_name)
34 |         if ds_name not in KNOWN_DATASETS:
35 |             raise NotImplementedError("Dataset unknown to 'load_dataset()' in 'utils/utils.py'")
36 |
37 |         with open(ds_name, "rb") as f:
38 |             print("Found dataset:", ds_name)
39 |             data = pickle.load(f, encoding="latin1")
40 |         graphs = data["graph"]
41 |         labels = np.array(data["labels"], dtype=float)
42 |
43 |         return graphs, labels
44 |
45 |     def normalize_label(self, l):
46 |         # map labels from {-1, 1} to {0, 1}
47 |         n = (l + 1) // 2
48 |         if n != 0 and n != 1:
49 |             raise ValueError("expected label in {-1, 1}, got %s" % l)
50 |         return n
51 |
52 |
53 |     def load_dataset(self, ds_name, device="cuda:0", seed=42, normalize=True, onehot=True):
54 |
55 |         graphs, labels = self.read_file(ds_name)
56 |         self.labels_set = set(labels)
57 |
58 |         # Compute shape dimensions n_graphs, n_nodes, feature_size
59 |         self.n_graphs = len(graphs)
60 |         self.n_nodes = -1  # Max number of nodes among all of the graphs
61 |         self.onehot = onehot
62 |
63 |         # Find the feature size (scalar or onehot)
64 |         if self.onehot:
65 |             # Size of the onehot feature vector = the maximum value present in the dataset
66 |             self.feature_size = 0
67 |             # min_value = 1000
68 |             for i in range(self.n_graphs):
69 |                 for j in range(len(graphs[i])):
70 |                     for _, d in enumerate(graphs[i][j]['label']):
71 |                         self.feature_size = max(self.feature_size, d)
72 |                         # min_value = min(min_value, d)
73 |         else:
74 |             self.feature_size = len(graphs[0][0]['label'])  # Feature array size for each node
75 |
76 |         # Find the maximum number of nodes
77 |         for gidxs, graph in graphs.items():
78 |             self.n_nodes = max(self.n_nodes, len(graph))
79 |         assert self.n_nodes > 0, "Apparently, there are no nodes in these graphs"
80 |
81 |         # Generate train and valid splits
82 |         torch.manual_seed(seed)
83 |         shuffled_idx = torch.randperm(self.n_graphs)
84 |         self.train_size = int(self.n_graphs * self.TRAIN_RATIO)
85 |         self.valid_size = self.n_graphs - self.train_size
86 |         train_idx = shuffled_idx[:self.train_size]
87 |         valid_idx = shuffled_idx[self.train_size:]
88 |
89 |         # Generate PyTorch tensors for train
90 |         a_train = torch.zeros((self.train_size, self.n_nodes, self.n_nodes), dtype=torch.float, device=device)
91 |         x_train = torch.zeros((self.train_size, self.n_nodes, self.feature_size), dtype=torch.float, device=device)
92 |         labels_train = torch.LongTensor(self.train_size).to(device=device)
93 |         for i in range(self.train_size):
94 |             idx = train_idx[i].item()
95 |             labels_train[i] = self.normalize_label(labels[idx])
96 |             for j in range(len(graphs[idx])):
97 |                 for k in graphs[idx][j]['neighbors']:
98 |                     a_train[i, j, k] = 1
99 |                 for k, d in enumerate(graphs[idx][j]['label']):
100 |                     if self.onehot:
101 |                         x_train[i, j, :] = to_onehot(d, self.feature_size, device)
102 |                     else:
103 |                         x_train[i, j, k] = float(d)
104 |
105 |         # Generate PyTorch tensors for valid
106 |         a_valid = torch.zeros((self.valid_size, self.n_nodes, self.n_nodes), dtype=torch.float, device=device)
107 |         x_valid = torch.zeros((self.valid_size, self.n_nodes, self.feature_size), dtype=torch.float, device=device)
108 |         labels_valid = torch.LongTensor(self.valid_size).to(device=device)
109 |         for i in range(self.valid_size):
110 |             idx = valid_idx[i].item()
111 |             labels_valid[i] = self.normalize_label(labels[idx])
112 |             for j in range(len(graphs[idx])):
113 |                 for k in graphs[idx][j]['neighbors']:
114 |                     a_valid[i, j, k] = 1
115 |                 for k, d in enumerate(graphs[idx][j]['label']):
116 |                     if self.onehot:
117 |                         x_valid[i, j, :] = to_onehot(d, self.feature_size, device)
118 |                     else:
119 |                         x_valid[i, j, k] = float(d)
120 |
121 |         if normalize:
122 |             x_train = F.normalize(x_train, p=2, dim=1)
123 |             x_valid = F.normalize(x_valid, p=2, dim=1)
124 |
125 |         self.train = (x_train, a_train, labels_train)
126 |         self.valid = (x_valid, a_valid, labels_valid)
127 |
128 |
129 | def to_onehot(x, size, device="cuda:0"):
130 |     t = torch.zeros(size, device=device)
131 |     t[x - 1] = 1  # Feature values start at 1, so value v maps to index v - 1
132 |     return t
133 |
134 | def to_onehot_labels(x, device="cuda:0"):
135 |     t = torch.zeros(2, device=device)
136 |     if x.item() == 1:
137 |         t[1] = 1
138 |     else:
139 |         t[0] = 1
140 |     return t
141 |
142 | def generate_plot(train, valid, title="Insert Title here"):
143 |     plt.figure()
144 |     plt.plot(valid, label="Validation")
145 |     plt.plot(train, label="Training")
146 |     plt.xlabel("Epoch")
147 |     # plt.ylabel()
148 |     plt.legend()
149 |     plt.title(title)
150 |     # plt.show()
151 |     plt.savefig("./" + title + ".svg", format='svg', dpi=1000)
152 |     plt.close()  # free the figure; this is called repeatedly during training
153 |
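154 | if __name__ == "__main__":
155 |     # Quick check of the one-hot encoding (hypothetical values; MUTAG node labels run 1..7):
156 |     print(to_onehot(3, 7, device="cpu"))  # expected: tensor([0., 0., 1., 0., 0., 0., 0.])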
--------------------------------------------------------------------------------