├── Image ├── README.md └── workflow (1).jpg ├── README.md ├── LICENSE └── Model_Running.ipynb /Image/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Image/workflow (1).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PengyuanLiu1993/GSL-sidewalk-comfort/HEAD/Image/workflow (1).jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GSL-Sidewalk 2 | 3 | This repo is for our paper "Towards Human-centric Digital Twins: Leveraging Computer Vision and Graph Models to Predict Outdoor Comfort" published in [Sustainable Cities and Society](https://doi.org/10.1016/j.scs.2023.104480) 4 | 5 | 6 | 7 | # Prerequisite Packages: 8 | 9 | 1. [PyTorch](https://pytorch.org/) 10 | 2. [Deep Graph Library](https://www.dgl.ai/) 11 | 3. [networkx](https://networkx.org/) 12 | 4. [pandas](https://pandas.pydata.org/) and [geopandas](https://geopandas.org/) 13 | 5. [numpy](https://numpy.org/) 14 | 6. 
[jenkspy](https://pypi.org/project/jenkspy/) 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 PengyuanLiu1993 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Model_Running.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | }, 15 | "accelerator": "GPU", 16 | "gpuClass": "standard" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "colab": { 24 | "base_uri": "https://localhost:8080/" 25 | }, 26 | "id": "4OYe8RFFSdLS", 27 | "outputId": "9c1894d3-4817-4743-f1f3-07da36c50943" 28 | }, 29 | "outputs": [ 30 | { 31 | "output_type": "stream", 32 | "name": "stdout", 33 | "text": [ 34 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 35 | "Requirement already satisfied: dgl in /usr/local/lib/python3.7/dist-packages (0.9.0)\n", 36 | "Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.7/dist-packages (from dgl) (1.21.6)\n", 37 | "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from dgl) (2.23.0)\n", 38 | "Requirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.7/dist-packages (from dgl) (2.6.3)\n", 39 | "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.7/dist-packages (from dgl) (5.9.2)\n", 40 | "Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from dgl) (1.7.3)\n", 41 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from dgl) (4.64.0)\n", 42 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->dgl) (2022.6.15)\n", 43 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in 
# Notebook setup (Colab shell magics, run once before the imports below):
#   !pip install dgl
#   !pip install jenkspy

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import random
from jenkspy import JenksNaturalBreaks

import os
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import dgl
from dgl.nn import SAGEConv
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sn
import warnings
import scipy.stats as stats

warnings.filterwarnings('ignore')
random.seed(1)


class MyLSTM(nn.Module):
    """LSTM wrapper that returns only the output of the last time step."""

    def __init__(self, input_size, hidden_size, num_layers):
        super(MyLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        """Run the LSTM over ``x`` and return the last time step.

        :param x: tensor of shape (batch, time_step, input_size)
        :return: tensor of shape (batch, hidden_size)
        """
        # Initial hidden / cell states. They are created on the same device
        # and with the same dtype as the input so the module keeps working
        # when the model and data are moved off the CPU.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        c0 = torch.zeros_like(h0)

        # out: (batch, time_step, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Keep only the output of the final time step.
        return out[:, -1, :]


class MyDataset(Dataset):
    """Dataset yielding ``(feature, graph, label)`` triples by index."""

    def __init__(self, features, labels, graph):
        self.features = features
        self.labels = labels
        self.graph = graph

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.graph[idx], self.labels[idx]


def get_each_person(data, group_size=85):
    """Split the rows of ``data`` into consecutive groups (one per person).

    :param data: DataFrame whose rows are iterated in order
    :param group_size: number of consecutive rows per person (default 85)
    :return: list of groups, each a list of ``group_size`` row Series;
             a trailing incomplete group is dropped
    :raises ValueError: if ``group_size`` is not positive
    """
    if group_size <= 0:
        raise ValueError("group_size must be a positive integer")

    all_person, one_person = [], []
    for _, values in data.iterrows():
        one_person.append(values)
        if len(one_person) == group_size:
            all_person.append(one_person)
            one_person = []
    return all_person


def get_label_class(values):
    """Discretise a normalised label into a class index using fixed bins.

    Not called by ``read_data`` (which uses Jenks natural breaks instead).

    :param values: indexable whose first element is the normalised label
    :return: index of the first bin edge greater than ``values[0]``,
             or ``None`` when the value is >= the last edge (1.1)
    """
    bins = [0.5, 0.6, 1.1]  # three levels for now
    for i, upper in enumerate(bins):
        if upper > values[0]:
            return i
    return None


def get_graph(graph):
    """Build a DGL graph from a square adjacency matrix.

    Only entries strictly below the diagonal are read; every entry equal to
    1 at (i, j) with j < i becomes an edge i -> j. A self loop is added to
    every node.

    :param graph: adjacency matrix (14 x 14 in this notebook)
    :return: DGL graph with ``len(graph)`` nodes
    """
    src, dst = [], []
    for i, line in enumerate(graph):
        for j, node in enumerate(line):
            if j >= i:
                break  # only the strict lower triangle is used
            if node == 1:
                src.append(i)
                dst.append(j)
    # Fix the node count explicitly. Otherwise DGL infers it from the largest
    # endpoint id, and an isolated high-index node would yield a graph with
    # fewer nodes than there are feature rows.
    g = dgl.graph((src, dst), num_nodes=len(graph))
    g = dgl.add_self_loop(g)
    return g


def get_altitude_gap(data_edm):
    """Z-scored altitude differences from the ``'dem'`` column.

    Not called by ``read_data`` at the moment. The first gap is
    ``last - first``; the remaining gaps are consecutive differences.

    :param data_edm: DataFrame with a ``'dem'`` column
    :return: float32 array of z-scores, same length as the column
    """
    altitude = list(data_edm['dem'].values)
    gaps = [altitude[-1] - altitude[0]]
    gaps.extend(altitude[i + 1] - altitude[i] for i in range(len(altitude) - 1))
    zscores = stats.zscore(gaps)
    return zscores.astype(np.float32)


def read_data(path_fea, path_label, path_graph):
    """Load the raw data and build shuffled (feature, label, graph) samples.

    :param path_fea: CSV file with the per-step features
    :param path_label: header-less CSV with the labels (transposed on load)
    :param path_graph: ``.npy`` file with the adjacency matrices
    :return: ``(features, labels, graphs)`` lists, shuffled in unison
    """
    features, labels, graphs = [], [], []
    raw_fea = pd.read_csv(path_fea)
    raw_label = pd.read_csv(path_label, header=None).T
    # One flattened adjacency matrix per label row.
    raw_graph = pd.DataFrame(np.load(path_graph).reshape((len(raw_label), -1)))

    per_fea = get_each_person(raw_fea)
    per_label = get_each_person(raw_label)
    per_graph = get_each_person(raw_graph)

    time_window_node, time_window_lstm = 5, 5
    for p in range(len(per_fea)):
        # Left-pad each person's sequence with copies of its first row.
        first_row = per_fea[p][0]
        for _ in range(time_window_node):
            per_fea[p].insert(0, first_row)

        # Min-max normalise this person's labels, then cut them into three
        # classes with Jenks natural breaks.
        person_label = pd.DataFrame(per_label[p])
        label_norm = (person_label - person_label.min()) / (person_label.max() - person_label.min())
        jnb = JenksNaturalBreaks(3)
        jnb.fit(list(label_norm[0].values))
        label_classes = jnb.labels_

        person_fea = pd.DataFrame(per_fea[p])
        person_graph = pd.DataFrame(per_graph[p])

        # Inner sliding window: time_window_node consecutive rows each.
        lstm_fea = [
            person_fea.iloc[i:i + time_window_node].values.astype(np.float32)
            for i in range(len(person_fea) - time_window_node)
        ]

        # Outer sliding window: time_window_lstm consecutive inner windows
        # form one sample, paired with the matching graphs and one label.
        for i in range(len(lstm_fea) - time_window_lstm):
            features.append(np.array(lstm_fea[i:i + time_window_lstm]).astype(np.float32))
            graphs.append(person_graph.iloc[i:i + time_window_lstm].values.reshape(time_window_lstm, 14, -1))
            labels.append(label_classes[i + time_window_lstm])

    features, labels, graphs = shuffle(features, labels, graphs)

    return features, labels, graphs


def shuffle(features, labels, graphs):
    """Shuffle the three parallel lists with one shared random permutation.

    :param features: list of feature arrays
    :param labels: list of labels
    :param graphs: list of graph arrays
    :return: ``(new_features, new_labels, new_graphs)`` in the new order
    """
    index = list(range(len(features)))
    random.shuffle(index)
    new_features = [features[i] for i in index]
    new_labels = [labels[i] for i in index]
    new_graphs = [graphs[i] for i in index]
    return new_features, new_labels, new_graphs


class My_model(nn.Module):
    """GraphSAGE embedding per time step, then an LSTM and a 3-way classifier."""

    def __init__(self):
        super(My_model, self).__init__()
        self.graph_embedding = SAGEConv(5, 128, 'pool')
        self.flat = nn.Linear(128 * 14, 128)
        self.lstm = MyLSTM(input_size=128, hidden_size=128, num_layers=1)
        self.classify = nn.Linear(128, 3)
        # Kept for callers that want probabilities; forward() returns logits.
        # dim is given explicitly because the implicit-dim form is deprecated.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, graph):
        """Compute class logits for a batch.

        :param features: 4-D tensor; the last two axes are swapped so each
                         time step becomes a (node, channel) matrix
        :param graph: per-sample, per-time-step adjacency matrices
        :return: logits of shape (batch, 3); no softmax is applied because
                 ``nn.CrossEntropyLoss`` expects raw logits
        """
        features = features.permute(0, 1, 3, 2)
        batch, time_step, node, channel = features.shape

        cov_fea = []
        for i in range(batch):
            steps = []
            for j in range(time_step):
                fea = features[i, j, :, :]
                g = get_graph(graph[i][j])
                # Node embeddings flattened into one vector per time step.
                graph_cov = self.graph_embedding(g, fea).view(-1)
                steps.append(self.flat(graph_cov))
            cov_fea.append(torch.stack(steps, dim=0))
        cov_fea_tensor = torch.stack(cov_fea, dim=0)

        lstm_fea = self.lstm(cov_fea_tensor)
        return self.classify(lstm_fea)


def val_model(model, val_dataloader, loss_fun=None):
    """Evaluate ``model`` on ``val_dataloader``.

    :param model: model called as ``model(features, graph)``
    :param val_dataloader: iterable of ``(features, graph, label)`` batches
    :param loss_fun: loss callable; defaults to ``nn.CrossEntropyLoss()`` so
                     the function no longer depends on a global variable
    :return: ``(accuracy, mean_batch_loss)``
    """
    if loss_fun is None:
        loss_fun = nn.CrossEntropyLoss()

    y_truths, y_preds, val_loss = [], [], []
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            y = batch[2]
            y_pre = model(batch[0], batch[1])

            val_loss.append(loss_fun(y_pre, y).item())
            # Record every sample in the batch, not just the first one.
            y_truths.extend(y.cpu().numpy().tolist())
            y_preds.extend(y_pre.argmax(dim=1).cpu().numpy().tolist())
    val_loss = np.mean(val_loss)
    acc = metrics.accuracy_score(y_truths, y_preds)
    return acc, val_loss


def train_model(model, loss_fun, optimizer, train_dataloader, val_dataloader, epochs):
    """Train ``model`` and checkpoint the best validation accuracy.

    The state dict is written to ``'epoch_best.pkl'`` whenever the
    validation accuracy improves.

    :return: the model in its state after the final epoch
    """
    best_acc = float('-inf')
    for epo in range(epochs):
        model.train()  # val_model() switches to eval mode; switch back
        losses = []
        for batch in train_dataloader:
            y = batch[2]
            # Clear gradients left from the previous step; without this
            # they accumulate across batches.
            optimizer.zero_grad()
            y_dot = model(batch[0], batch[1])
            loss = loss_fun(y_dot, y)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        train_loss = np.mean(losses)
        val_acc, val_loss = val_model(model, val_dataloader, loss_fun)

        # Save only when validation accuracy improves.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'epoch_best.pkl')

        print('Epochs: %s/%s Train loss: %.6f Val loss: %.6f verification accuracy: %.6f' % (epo, epochs, train_loss, val_loss, val_acc))

    print("Train finished !")
    return model


def test_model(model, test_dataloader):
    """Print accuracy and the confusion matrix on ``test_dataloader``.

    :return: the confusion matrix
    """
    y_truths, y_preds = [], []
    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            y = batch[2]
            y_pre = model(batch[0], batch[1])

            # Record every sample in the batch, not just the first one.
            y_truths.extend(y.cpu().numpy().tolist())
            y_preds.extend(y_pre.argmax(dim=1).cpu().numpy().tolist())
    acc = metrics.accuracy_score(y_truths, y_preds)
    confusion_matrix = metrics.confusion_matrix(y_truths, y_preds)
    print(acc)
    print(confusion_matrix)
    return confusion_matrix


if __name__ == '__main__':
    path_fea = 'zscore_new_23-08-v2.csv'
    path_label = 'label-v2.csv'
    path_graph = 'spatial_interactive_graphs_v3.npy'

    features, labels, graphs = read_data(path_fea, path_label, path_graph)
    features = stats.zscore(features)

    epochs = 100
    batch_size = 128
    lr = 0.0001

    train_per = 0.7
    val_per = 0.1
    train_num = int(len(features) * train_per)
    val_num = int(len(features) * val_per)
    val_end = train_num + val_num

    # Train / validation / test splits. The validation set is held out from
    # the test set so that checkpoint selection never sees test data.
    dataset_train = MyDataset(features[:train_num], labels[:train_num], graphs[:train_num])
    dataset_val = MyDataset(features[train_num:val_end], labels[train_num:val_end], graphs[train_num:val_end])
    dataset_test = MyDataset(features[val_end:], labels[val_end:], graphs[val_end:])

    train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dataset_val)
    test_dataloader = DataLoader(dataset_test)

    model = My_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fun = nn.CrossEntropyLoss()

    train = True
    if train:
        train_model(model, loss_fun, optimizer, train_dataloader, val_dataloader, epochs)
    else:
        model.load_state_dict(torch.load('epoch_best.pkl'))

    confusion_matrix = test_model(model, test_dataloader)

    # sn.heatmap(confusion_matrix, annot=True)
    # plt.xlabel('Predict')
    # plt.ylabel('Truth')
    # plt.show()
| "Epochs: 11/100 Train loss: 0.979013 Val loss: 0.935908 verification accuracy: 0.541667\n", 471 | "Epochs: 12/100 Train loss: 0.941083 Val loss: 0.924275 verification accuracy: 0.547222\n", 472 | "Epochs: 13/100 Train loss: 0.946251 Val loss: 0.949300 verification accuracy: 0.523611\n", 473 | "Epochs: 14/100 Train loss: 0.937128 Val loss: 0.976417 verification accuracy: 0.479167\n", 474 | "Epochs: 15/100 Train loss: 0.950255 Val loss: 0.967961 verification accuracy: 0.488889\n", 475 | "Epochs: 16/100 Train loss: 0.940385 Val loss: 0.938127 verification accuracy: 0.519444\n", 476 | "Epochs: 17/100 Train loss: 0.919721 Val loss: 0.923453 verification accuracy: 0.544444\n", 477 | "Epochs: 18/100 Train loss: 0.916439 Val loss: 0.915822 verification accuracy: 0.537500\n", 478 | "Epochs: 19/100 Train loss: 0.916698 Val loss: 0.900947 verification accuracy: 0.550000\n", 479 | "Epochs: 20/100 Train loss: 0.901217 Val loss: 0.886554 verification accuracy: 0.576389\n", 480 | "Epochs: 21/100 Train loss: 0.887568 Val loss: 0.879281 verification accuracy: 0.594444\n", 481 | "Epochs: 22/100 Train loss: 0.901625 Val loss: 0.881562 verification accuracy: 0.583333\n", 482 | "Epochs: 23/100 Train loss: 0.847066 Val loss: 0.882438 verification accuracy: 0.587500\n", 483 | "Epochs: 24/100 Train loss: 0.864573 Val loss: 0.881171 verification accuracy: 0.593056\n", 484 | "Epochs: 25/100 Train loss: 0.846220 Val loss: 0.877841 verification accuracy: 0.587500\n", 485 | "Epochs: 26/100 Train loss: 0.846744 Val loss: 0.874630 verification accuracy: 0.581944\n", 486 | "Epochs: 27/100 Train loss: 0.833987 Val loss: 0.870709 verification accuracy: 0.561111\n", 487 | "Epochs: 28/100 Train loss: 0.818108 Val loss: 0.870221 verification accuracy: 0.545833\n", 488 | "Epochs: 29/100 Train loss: 0.829296 Val loss: 0.867080 verification accuracy: 0.544444\n", 489 | "Epochs: 30/100 Train loss: 0.824163 Val loss: 0.859649 verification accuracy: 0.563889\n", 490 | "Epochs: 31/100 Train loss: 0.780355 
Val loss: 0.847469 verification accuracy: 0.573611\n", 491 | "Epochs: 32/100 Train loss: 0.776136 Val loss: 0.835176 verification accuracy: 0.597222\n", 492 | "Epochs: 33/100 Train loss: 0.779187 Val loss: 0.825521 verification accuracy: 0.602778\n", 493 | "Epochs: 34/100 Train loss: 0.749739 Val loss: 0.816702 verification accuracy: 0.598611\n", 494 | "Epochs: 35/100 Train loss: 0.752438 Val loss: 0.809058 verification accuracy: 0.613889\n", 495 | "Epochs: 36/100 Train loss: 0.741759 Val loss: 0.802788 verification accuracy: 0.616667\n", 496 | "Epochs: 37/100 Train loss: 0.724967 Val loss: 0.794025 verification accuracy: 0.623611\n", 497 | "Epochs: 38/100 Train loss: 0.728588 Val loss: 0.786157 verification accuracy: 0.634722\n", 498 | "Epochs: 39/100 Train loss: 0.698120 Val loss: 0.781752 verification accuracy: 0.626389\n", 499 | "Epochs: 40/100 Train loss: 0.697589 Val loss: 0.776357 verification accuracy: 0.629167\n", 500 | "Epochs: 41/100 Train loss: 0.703357 Val loss: 0.773089 verification accuracy: 0.626389\n", 501 | "Epochs: 42/100 Train loss: 0.700344 Val loss: 0.769161 verification accuracy: 0.633333\n", 502 | "Epochs: 43/100 Train loss: 0.665247 Val loss: 0.760302 verification accuracy: 0.637500\n", 503 | "Epochs: 44/100 Train loss: 0.660243 Val loss: 0.752613 verification accuracy: 0.656944\n", 504 | "Epochs: 45/100 Train loss: 0.647325 Val loss: 0.748461 verification accuracy: 0.654167\n", 505 | "Epochs: 46/100 Train loss: 0.642837 Val loss: 0.744346 verification accuracy: 0.647222\n", 506 | "Epochs: 47/100 Train loss: 0.613564 Val loss: 0.736671 verification accuracy: 0.648611\n", 507 | "Epochs: 48/100 Train loss: 0.605683 Val loss: 0.727445 verification accuracy: 0.654167\n", 508 | "Epochs: 49/100 Train loss: 0.606255 Val loss: 0.720826 verification accuracy: 0.662500\n", 509 | "Epochs: 50/100 Train loss: 0.598716 Val loss: 0.718848 verification accuracy: 0.665278\n", 510 | "Epochs: 51/100 Train loss: 0.582294 Val loss: 0.721984 verification 
accuracy: 0.658333\n", 511 | "Epochs: 52/100 Train loss: 0.578866 Val loss: 0.719750 verification accuracy: 0.658333\n", 512 | "Epochs: 53/100 Train loss: 0.569094 Val loss: 0.711961 verification accuracy: 0.672222\n", 513 | "Epochs: 54/100 Train loss: 0.538340 Val loss: 0.704537 verification accuracy: 0.673611\n", 514 | "Epochs: 55/100 Train loss: 0.533449 Val loss: 0.704208 verification accuracy: 0.681944\n", 515 | "Epochs: 56/100 Train loss: 0.533614 Val loss: 0.700818 verification accuracy: 0.672222\n", 516 | "Epochs: 57/100 Train loss: 0.527320 Val loss: 0.696045 verification accuracy: 0.676389\n", 517 | "Epochs: 58/100 Train loss: 0.526178 Val loss: 0.693515 verification accuracy: 0.675000\n", 518 | "Epochs: 59/100 Train loss: 0.520110 Val loss: 0.692681 verification accuracy: 0.677778\n", 519 | "Epochs: 60/100 Train loss: 0.511868 Val loss: 0.685696 verification accuracy: 0.686111\n", 520 | "Epochs: 61/100 Train loss: 0.489671 Val loss: 0.676475 verification accuracy: 0.700000\n", 521 | "Epochs: 62/100 Train loss: 0.498368 Val loss: 0.668068 verification accuracy: 0.698611\n", 522 | "Epochs: 63/100 Train loss: 0.475682 Val loss: 0.665374 verification accuracy: 0.705556\n", 523 | "Epochs: 64/100 Train loss: 0.490258 Val loss: 0.669618 verification accuracy: 0.694444\n", 524 | "Epochs: 65/100 Train loss: 0.492251 Val loss: 0.665824 verification accuracy: 0.691667\n", 525 | "Epochs: 66/100 Train loss: 0.464553 Val loss: 0.654941 verification accuracy: 0.706944\n", 526 | "Epochs: 67/100 Train loss: 0.470905 Val loss: 0.644226 verification accuracy: 0.713889\n", 527 | "Epochs: 68/100 Train loss: 0.452883 Val loss: 0.643413 verification accuracy: 0.727778\n", 528 | "Epochs: 69/100 Train loss: 0.477583 Val loss: 0.650075 verification accuracy: 0.713889\n", 529 | "Epochs: 70/100 Train loss: 0.460498 Val loss: 0.648872 verification accuracy: 0.718056\n", 530 | "Epochs: 71/100 Train loss: 0.453689 Val loss: 0.635888 verification accuracy: 0.733333\n", 531 | "Epochs: 
72/100 Train loss: 0.432183 Val loss: 0.623048 verification accuracy: 0.737500\n", 532 | "Epochs: 73/100 Train loss: 0.429449 Val loss: 0.621902 verification accuracy: 0.733333\n", 533 | "Epochs: 74/100 Train loss: 0.441297 Val loss: 0.631687 verification accuracy: 0.730556\n", 534 | "Epochs: 75/100 Train loss: 0.457833 Val loss: 0.640138 verification accuracy: 0.718056\n", 535 | "Epochs: 76/100 Train loss: 0.445599 Val loss: 0.629981 verification accuracy: 0.716667\n", 536 | "Epochs: 77/100 Train loss: 0.430992 Val loss: 0.604478 verification accuracy: 0.741667\n", 537 | "Epochs: 78/100 Train loss: 0.418427 Val loss: 0.600400 verification accuracy: 0.748611\n", 538 | "Epochs: 79/100 Train loss: 0.405561 Val loss: 0.612069 verification accuracy: 0.747222\n", 539 | "Epochs: 80/100 Train loss: 0.424391 Val loss: 0.634006 verification accuracy: 0.737500\n", 540 | "Epochs: 81/100 Train loss: 0.460735 Val loss: 0.653662 verification accuracy: 0.720833\n", 541 | "Epochs: 82/100 Train loss: 0.445150 Val loss: 0.650375 verification accuracy: 0.716667\n", 542 | "Epochs: 83/100 Train loss: 0.436240 Val loss: 0.634226 verification accuracy: 0.738889\n", 543 | "Epochs: 84/100 Train loss: 0.405173 Val loss: 0.617442 verification accuracy: 0.752778\n", 544 | "Epochs: 85/100 Train loss: 0.402503 Val loss: 0.616155 verification accuracy: 0.752778\n", 545 | "Epochs: 86/100 Train loss: 0.400538 Val loss: 0.629711 verification accuracy: 0.734722\n", 546 | "Epochs: 87/100 Train loss: 0.420454 Val loss: 0.631828 verification accuracy: 0.740278\n", 547 | "Epochs: 88/100 Train loss: 0.413040 Val loss: 0.628637 verification accuracy: 0.748611\n", 548 | "Epochs: 89/100 Train loss: 0.406601 Val loss: 0.618119 verification accuracy: 0.740278\n", 549 | "Epochs: 90/100 Train loss: 0.396085 Val loss: 0.607425 verification accuracy: 0.754167\n", 550 | "Epochs: 91/100 Train loss: 0.380656 Val loss: 0.595041 verification accuracy: 0.756944\n", 551 | "Epochs: 92/100 Train loss: 0.383862 Val loss: 
0.605807 verification accuracy: 0.750000\n", 552 | "Epochs: 93/100 Train loss: 0.391439 Val loss: 0.615453 verification accuracy: 0.745833\n", 553 | "Epochs: 94/100 Train loss: 0.393306 Val loss: 0.621385 verification accuracy: 0.740278\n", 554 | "Epochs: 95/100 Train loss: 0.391206 Val loss: 0.625824 verification accuracy: 0.738889\n", 555 | "Epochs: 96/100 Train loss: 0.389325 Val loss: 0.614228 verification accuracy: 0.741667\n", 556 | "Epochs: 97/100 Train loss: 0.381482 Val loss: 0.602450 verification accuracy: 0.765278\n", 557 | "Epochs: 98/100 Train loss: 0.366651 Val loss: 0.586222 verification accuracy: 0.769444\n", 558 | "Epochs: 99/100 Train loss: 0.362413 Val loss: 0.581642 verification accuracy: 0.765278\n", 559 | "Train finished !\n", 560 | "0.7652777777777777\n", 561 | "[[114 53 8]\n", 562 | " [ 13 266 37]\n", 563 | " [ 6 52 171]]\n" 564 | ] 565 | } 566 | ] 567 | } 568 | ] 569 | } 570 | --------------------------------------------------------------------------------