├── fig
│   └── layout.jpg
├── code
│   ├── best_user.pt
│   ├── best_default.pt
│   ├── node_evaluation.py
│   ├── SAGEE.py
│   └── train&test.ipynb
├── dataset
│   └── roomgraph.bin
├── LICENSE
└── README.md

/fig/layout.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZijianWang-ZW/SAGE-E/HEAD/fig/layout.jpg

--------------------------------------------------------------------------------
/code/best_user.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZijianWang-ZW/SAGE-E/HEAD/code/best_user.pt

--------------------------------------------------------------------------------
/code/best_default.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZijianWang-ZW/SAGE-E/HEAD/code/best_default.pt

--------------------------------------------------------------------------------
/dataset/roomgraph.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZijianWang-ZW/SAGE-E/HEAD/dataset/roomgraph.bin

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Zijian Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/code/node_evaluation.py:
--------------------------------------------------------------------------------
#%%
import numpy as np
import torch
import torch.nn.functional as F
import dgl
#%%


# collect single small graphs into a batch
def collate(graphs):
    graph = dgl.batch(graphs)
    return graph


def evalEdge(model, nfeat, efeat, subgraph, labels, n_classes):
    """Feed the model with node and edge features; return the accuracy, loss,
    per-class statistics, predictions, and ground-truth labels."""

    with torch.no_grad():
        model.eval()
        # subgraph = subgraph.to(device)

        # output the prediction results
        logits = model(subgraph, nfeat, efeat)

        # calculate the accuracy
        gt = torch.argmax(labels, dim=1)  # labels = subgraph.ndata['label']
        pre = torch.argmax(logits, dim=1)
        correct = torch.sum(pre == gt)
        acc = correct.item() * 1.0 / len(gt)

        # compute the loss
        loss = F.cross_entropy(logits, gt)

        # count the correct predictions for each class
        one_class_correct, one_class_total = accEachClass(gt, pre, n_classes)

    return acc, loss, one_class_correct, one_class_total, pre, gt


def accEachClass(gt, pre, n_classes):
    one_class_correct = list(0. for i in range(n_classes))
    one_class_total = list(0. for i in range(n_classes))

    for i in range(len(gt)):
        # for each correct prediction, +1
        if gt[i] == pre[i]:
            one_class_correct[gt[i]] += 1

        one_class_total[gt[i]] += 1

    return one_class_correct, one_class_total
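

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal, hedged example of evaluating a single RoomGraph graph with
# evalEdge. It assumes the dataset path, the 'feat'/'relation'/'label' keys,
# and the bundled best_default.pt checkpoint used in train&test.ipynb, and
# that it is run from inside the code/ directory.
if __name__ == "__main__":
    from dgl.data.utils import load_graphs
    from SAGEE import SAGEE  # noqa: F401 -- needed so torch.load can unpickle the model

    graphs = load_graphs("../dataset/roomgraph.bin")[0]
    g = collate(graphs[:1])  # a batch holding a single graph

    model = torch.load("best_default.pt")  # the notebook pickles the full model

    nfeat = g.ndata['feat'].float()
    efeat = g.edata['relation'].float()
    acc, loss, cls_correct, cls_total, pre, gt = evalEdge(
        model, nfeat, efeat, g, g.ndata['label'], n_classes=9)
    print("accuracy: {:.4f} | loss: {:.4f}".format(acc, loss.item()))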

--------------------------------------------------------------------------------
/code/SAGEE.py:
--------------------------------------------------------------------------------
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


class SAGEELayer(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation):
        super(SAGEELayer, self).__init__()
        self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
        self.W_apply = nn.Linear(ndim_in + ndim_out, ndim_out)
        self.activation = activation

    def message_func(self, edges):
        return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 1)))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = nfeats
            g.edata['h'] = efeats
            # aggregator function
            g.update_all(self.message_func, fn.sum('m', 'h_neigh'))

            # update function
            g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1)))
            return g.ndata['h']


# We adopt a 4-layer SAGE-E model here, which gave the best performance in this setting.
class SAGEE(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super(SAGEE, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(SAGEELayer(ndim_in, edim, 50, activation))
        self.layers.append(SAGEELayer(50, edim, 50, activation))
        self.layers.append(SAGEELayer(50, edim, 25, activation))
        self.layers.append(SAGEELayer(25, edim, ndim_out, activation))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, nfeats, efeats):
        for i, layer in enumerate(self.layers):
            if i != 0:
                nfeats = self.dropout(nfeats)
            nfeats = layer(g, nfeats, efeats)
        return nfeats

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Applying GNN to BIM graphs for semantic enrichment


We present a novel approach to semantic enrichment, in which we represent BIM models as graphs and apply GNNs to those graphs.

We select a typical semantic enrichment task -- apartment room type classification -- to test our approach.

To achieve this goal, we created a BIM graph dataset, named **RoomGraph**, and modified a classic GNN algorithm into **SAGE-E**, which leverages both node and edge features.

The RoomGraph dataset and the source code of SAGE-E are open to public research use. Enjoy!


# Installation

## Install Dependencies
1. Clone or download this repository:
```bash
git clone https://github.com/ZijianWang-ZW/SAGE-E.git
cd SAGE-E
```

2. Install [Anaconda](https://www.anaconda.com/download) and follow the steps below.

3. Create a new conda environment and install all required packages:

```bash
conda create -n gnntutorial python=3.8 -y
```

```bash
conda activate gnntutorial
```

Install Jupyter Notebook:
```bash
conda install jupyter notebook -y
```

Install PyTorch. Training and testing SAGE-E do not require any special GPU configuration; CPU processing is sufficient for the provided dataset.
```bash
conda install pytorch=2.3.0 torchvision torchaudio cpuonly -c pytorch
```

```bash
conda install -c pytorch torchdata
```

```bash
conda install pydantic -y
```

Install DGL:
```bash
conda install -c dglteam dgl
```

Install all other required libraries:
```bash
conda install numpy pandas scikit-learn
```


## Folder Structure
The following shows the basic folder structure:
```
├── code/
│   ├── SAGEE.py            # The SAGE-E GNN architecture implementation
│   ├── best_default.pt     # Pre-trained model weights (default configuration)
│   ├── best_user.pt        # Pre-trained model weights (user configuration)
│   ├── node_evaluation.py  # Utility functions for model evaluation
│   └── train&test.ipynb    # Main training and testing notebook
├── dataset/
│   └── roomgraph.bin       # RoomGraph dataset (BIM graph data)
├── fig/
│   └── layout.jpg          # Project illustration
├── README.md               # Project documentation
└── LICENSE                 # MIT License
```

# Usage

## Quick Start
1. **Open the main notebook**: Navigate to `code/train&test.ipynb`
2. **Run the cells step by step**: The notebook contains detailed explanations for each step

## Step-by-Step Guide

### 1. Load the Dataset
The RoomGraph dataset is provided in `dataset/roomgraph.bin`. The notebook will automatically load this dataset:
```python
from dgl.data.utils import load_graphs
bg = load_graphs("../dataset/roomgraph.bin")[0]
```

### 2. Model Architecture
The SAGE-E model is defined in `code/SAGEE.py`. It consists of:
- **SAGEELayer**: an individual layer that processes both node and edge features
- **SAGEE**: a 4-layer network optimized for room type classification (instantiated as sketched below)
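
A minimal sketch of creating the model, mirroring the notebook's setup. The `feat`/`relation` keys and the nine room classes come from the RoomGraph dataset; a dropout of 0.2 is the notebook's default:

```python
import torch.nn.functional as F
from dgl.data.utils import load_graphs

from SAGEE import SAGEE

graphs = load_graphs("../dataset/roomgraph.bin")[0]
ndim_in = graphs[0].ndata['feat'].shape[1]      # node feature dimension
edim_in = graphs[0].edata['relation'].shape[1]  # edge feature dimension

model = SAGEE(ndim_in, ndim_out=9, edim=edim_in, activation=F.relu, dropout=0.2)
```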

### 3. Training
Run the training cells in the notebook to:
- Split the data into train/validation/test sets
- Initialize the SAGE-E model
- Train it with the specified hyperparameters
- Monitor training progress

### 4. Evaluation
The notebook includes a comprehensive evaluation (a condensed sketch follows below):
- F1-score calculation
- Confusion matrix generation
- Model performance metrics
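
The sketch below condenses the notebook's test cell into a standalone loop. It assumes `model`, `test_dataloader`, and `n_classes` have been set up as in `train&test.ipynb`, and reuses `evalEdge` from `code/node_evaluation.py`:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

from node_evaluation import evalEdge

acc_list, pre, gt = [], [], []
for subgraph in test_dataloader:
    nfeat = subgraph.ndata['feat'].float()
    efeat = subgraph.edata['relation'].float()
    acc, _, _, _, one_pre, one_gt = evalEdge(
        model, nfeat, efeat, subgraph, subgraph.ndata['label'], n_classes)
    acc_list.append(acc)  # per-graph accuracy
    pre.extend(one_pre)   # predicted labels
    gt.extend(one_gt)     # ground-truth labels

print("Test accuracy: {:.4f}".format(np.mean(acc_list)))
print("Macro F1: {:.4f}".format(f1_score(gt, pre, average='macro')))
print(confusion_matrix(gt, pre))
```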

### 5. Using Pre-trained Models
Two pre-trained models are provided:
- `best_default.pt`: Model with default hyperparameters
- `best_user.pt`: Model with optimized hyperparameters

Load a pre-trained model (note that the notebook saves the entire pickled model, not just a state dict):
```python
model = torch.load('best_default.pt')
model.eval()
```

## Customization

### Using Your Own Data
To use SAGE-E with your own BIM graph data:
1. Convert your data to the DGL graph format
2. Ensure node and edge features are properly formatted
3. Modify the data loading section in the notebook
4. Adjust model parameters if needed

### Hyperparameter Tuning
Key hyperparameters you can modify:
- `batch_size`: Batch size for training
- `epochs`: Number of training epochs
- `lr`: Learning rate
- `dropout`: Dropout rate for regularization
- Network architecture (layer dimensions in `SAGEE.py`)

## Output
The model performs room type classification on BIM graphs, outputting:
- Predicted room types for each node
- Classification confidence scores
- Performance metrics (F1-score, accuracy, confusion matrix)

## Citation

If you use this code or the RoomGraph dataset in your research, please cite our paper:

```bibtex
@article{WANG2022104039,
  title   = {Exploring graph neural networks for semantic enrichment: Room type classification},
  journal = {Automation in Construction},
  volume  = {134},
  pages   = {104039},
  year    = {2022},
  issn    = {0926-5805},
  doi     = {10.1016/j.autcon.2021.104039},
  author  = {Zijian Wang and Rafael Sacks and Timson Yeung}
}
```

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Contact
Feel free to contact Zijian Wang (zijianwang1995@gmail.com) if you have any questions.

If you want to know more about my work, please visit: https://zijianwang-zw.github.io/

--------------------------------------------------------------------------------
/code/train&test.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## This notebook will guide you through an interesting experiment -- using a GNN for a BIM node classification task." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "DGL backend not selected or invalid. Assuming PyTorch for now.\n", 20 | "Using backend: pytorch\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Setting the default backend to \"pytorch\". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable. Valid options are: pytorch, mxnet, tensorflow (all lowercase)\n", 28 | "Using device: cpu\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "import os \n", 34 | "# torch and dgl\n", 35 | "import torch\n", 36 | "from torch.utils.data import DataLoader\n", 37 | "import torch.nn.functional as F\n", 38 | "from dgl.data.utils import load_graphs\n", 39 | "# basic machine learning libs\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "import time\n", 43 | "from sklearn.model_selection import train_test_split\n", 44 | "from sklearn.metrics import f1_score, confusion_matrix\n", 45 | "\n", 46 | "# self-constructed functions\n", 47 | "from node_evaluation import collate, evalEdge \n", 48 | "from SAGEE import SAGEE\n", 49 | "\n", 50 | "pd.options.mode.chained_assignment = None # default='warn'\n", 51 | "\n", 52 | "device = torch.device('cpu') # CPU is enough for processing small graphs\n", 53 | "print('Using device:', device)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Set basic parameters. The default number of training epochs is 200. You can play with different hyperparameters. " 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "epochs = 200\n", 70 | "batch_size = 1\n", 71 | "n_classes = 9 # nine room classes here\n", 72 | "weight_decay = 5e-4\n", 73 | "num_channels = 50\n", 74 | "lr = 0.005" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Load the RoomGraph dataset.\n", 82 | "\n", 83 | "RoomGraph is a self-designed graph dataset containing 224 apartment layouts collected from 3 countries. \n", 84 | "\n", 85 | "RoomGraph has 9 different node classes, and each node and edge has its own feature vector." 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "train dataset 161, val dataset 18, test dataset 45\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "bg = load_graphs(\"./../dataset/roomgraph.bin\")[0]\n", 103 | "\n", 104 | "# data split\n", 105 | "trainvalid, test_dataset = train_test_split(bg, test_size=0.2, random_state=42)\n", 106 | "train_dataset, valid_dataset = train_test_split(trainvalid, test_size=0.1, random_state=42)\n", 107 | "\n", 108 | "# data batch for parallel computation\n", 109 | "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)\n", 110 | "valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)\n", 111 | "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)\n", 112 | "\n", 113 | "print(\"train dataset %i, val dataset %i, test dataset %i\"%(len(train_dataset), \\\n", 114 | " len(valid_dataset), len(test_dataset)))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "### Load SAGE-E model. \n", 122 | "\n", 123 | "SAGE-E is an improved algorithm based on [GraphSAGE](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf). 
\n", 124 | "\n", 125 | "The main improvement is that SAGE-E can leverage both node and edge features, but GraphSAGE can only learn from node features." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 5, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "SAGEE(\n", 138 | " (layers): ModuleList(\n", 139 | " (0): SAGEELayer(\n", 140 | " (W_msg): Linear(in_features=13, out_features=50, bias=True)\n", 141 | " (W_apply): Linear(in_features=58, out_features=50, bias=True)\n", 142 | " )\n", 143 | " (1): SAGEELayer(\n", 144 | " (W_msg): Linear(in_features=55, out_features=50, bias=True)\n", 145 | " (W_apply): Linear(in_features=100, out_features=50, bias=True)\n", 146 | " )\n", 147 | " (2): SAGEELayer(\n", 148 | " (W_msg): Linear(in_features=55, out_features=25, bias=True)\n", 149 | " (W_apply): Linear(in_features=75, out_features=25, bias=True)\n", 150 | " )\n", 151 | " (3): SAGEELayer(\n", 152 | " (W_msg): Linear(in_features=30, out_features=9, bias=True)\n", 153 | " (W_apply): Linear(in_features=34, out_features=9, bias=True)\n", 154 | " )\n", 155 | " )\n", 156 | " (dropout): Dropout(p=0.2, inplace=False)\n", 157 | ")\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "# model loading \n", 163 | "ndim_in = train_dataset[0].ndata['feat'].shape[1]\n", 164 | "edim_in = train_dataset[0].edata['relation'].shape[1]\n", 165 | "\n", 166 | "model = SAGEE(ndim_in, n_classes, edim_in, F.relu, 0.2)\n", 167 | "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n", 168 | "model = model.to(device)\n", 169 | "print(model)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "### Start to train the model" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 6, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "Epoch 001 train | Accuracy: 0.4039 | Loss: 1.7499\n", 189 | "Validation | Accuracy: 0.5497 | Loss: 1.3826\n", 190 | "\n", 191 | "Epoch 002 train | Accuracy: 0.5778 | Loss: 1.3388\n", 192 | "Validation | Accuracy: 0.6221 | Loss: 1.1034\n", 193 | "\n", 194 | "Epoch 003 train | Accuracy: 0.6177 | Loss: 1.1886\n", 195 | "Validation | Accuracy: 0.6360 | Loss: 1.0711\n", 196 | "\n", 197 | "Epoch 004 train | Accuracy: 0.6358 | Loss: 1.1446\n", 198 | "Validation | Accuracy: 0.6771 | Loss: 1.0413\n", 199 | "\n", 200 | "Epoch 005 train | Accuracy: 0.6460 | Loss: 1.1100\n", 201 | "Validation | Accuracy: 0.6735 | Loss: 0.9827\n", 202 | "\n", 203 | "Epoch 006 train | Accuracy: 0.6518 | Loss: 1.0823\n", 204 | "Validation | Accuracy: 0.6702 | Loss: 1.0088\n", 205 | "\n", 206 | "Epoch 007 train | Accuracy: 0.6687 | Loss: 1.0631\n", 207 | "Validation | Accuracy: 0.6760 | Loss: 0.9831\n", 208 | "\n", 209 | "Epoch 008 train | Accuracy: 0.6642 | Loss: 1.0593\n", 210 | "Validation | Accuracy: 0.6752 | Loss: 0.9513\n", 211 | "\n", 212 | "Epoch 009 train | Accuracy: 0.6661 | Loss: 1.0547\n", 213 | "Validation | Accuracy: 0.6895 | Loss: 1.0017\n", 214 | "\n", 215 | "Epoch 010 train | Accuracy: 0.6777 | Loss: 1.0341\n", 216 | "Validation | Accuracy: 0.7218 | Loss: 0.9489\n", 217 | "\n", 218 | "Epoch 011 train | Accuracy: 0.6890 | Loss: 1.0184\n", 219 | "Validation | Accuracy: 0.7151 | Loss: 0.9500\n", 220 | "\n", 221 | "Epoch 012 train | Accuracy: 0.6921 | Loss: 1.0214\n", 222 | "Validation | Accuracy: 0.7520 | Loss: 0.9073\n", 223 | "\n", 224 | 
"Epoch 013 train | Accuracy: 0.6959 | Loss: 0.9990\n", 225 | "Validation | Accuracy: 0.7409 | Loss: 0.9369\n", 226 | "\n", 227 | "Epoch 014 train | Accuracy: 0.6753 | Loss: 1.0616\n", 228 | "Validation | Accuracy: 0.6455 | Loss: 0.9496\n", 229 | "\n", 230 | "Epoch 015 train | Accuracy: 0.6859 | Loss: 1.0245\n", 231 | "Validation | Accuracy: 0.7131 | Loss: 0.9297\n", 232 | "\n", 233 | "Epoch 016 train | Accuracy: 0.6930 | Loss: 1.0108\n", 234 | "Validation | Accuracy: 0.7383 | Loss: 0.9324\n", 235 | "\n", 236 | "Epoch 017 train | Accuracy: 0.7005 | Loss: 0.9985\n", 237 | "Validation | Accuracy: 0.7311 | Loss: 0.9217\n", 238 | "\n", 239 | "Epoch 018 train | Accuracy: 0.6987 | Loss: 0.9914\n", 240 | "Validation | Accuracy: 0.7519 | Loss: 0.8948\n", 241 | "\n", 242 | "Epoch 019 train | Accuracy: 0.7015 | Loss: 0.9987\n", 243 | "Validation | Accuracy: 0.7505 | Loss: 0.9097\n", 244 | "\n", 245 | "Epoch 020 train | Accuracy: 0.7065 | Loss: 0.9771\n", 246 | "Validation | Accuracy: 0.7575 | Loss: 0.8802\n", 247 | "\n", 248 | "Epoch 021 train | Accuracy: 0.7088 | Loss: 0.9836\n", 249 | "Validation | Accuracy: 0.7418 | Loss: 0.9158\n", 250 | "\n", 251 | "Epoch 022 train | Accuracy: 0.7011 | Loss: 0.9723\n", 252 | "Validation | Accuracy: 0.7621 | Loss: 0.8961\n", 253 | "\n", 254 | "Epoch 023 train | Accuracy: 0.6977 | Loss: 0.9833\n", 255 | "Validation | Accuracy: 0.7637 | Loss: 0.9080\n", 256 | "\n", 257 | "Epoch 024 train | Accuracy: 0.6968 | Loss: 0.9729\n", 258 | "Validation | Accuracy: 0.7411 | Loss: 0.8937\n", 259 | "\n", 260 | "Epoch 025 train | Accuracy: 0.7072 | Loss: 0.9664\n", 261 | "Validation | Accuracy: 0.7560 | Loss: 0.8586\n", 262 | "\n", 263 | "Epoch 026 train | Accuracy: 0.7155 | Loss: 0.9682\n", 264 | "Validation | Accuracy: 0.7621 | Loss: 0.8831\n", 265 | "\n", 266 | "Epoch 027 train | Accuracy: 0.7095 | Loss: 0.9727\n", 267 | "Validation | Accuracy: 0.7632 | Loss: 0.8840\n", 268 | "\n", 269 | "Epoch 028 train | Accuracy: 0.7061 | Loss: 0.9708\n", 270 | "Validation | Accuracy: 0.7513 | Loss: 0.8798\n", 271 | "\n", 272 | "Epoch 029 train | Accuracy: 0.7122 | Loss: 0.9604\n", 273 | "Validation | Accuracy: 0.7440 | Loss: 0.8801\n", 274 | "\n", 275 | "Epoch 030 train | Accuracy: 0.7097 | Loss: 0.9550\n", 276 | "Validation | Accuracy: 0.7563 | Loss: 1.0003\n", 277 | "\n", 278 | "Epoch 031 train | Accuracy: 0.7139 | Loss: 0.9758\n", 279 | "Validation | Accuracy: 0.7531 | Loss: 0.9140\n", 280 | "\n", 281 | "Epoch 032 train | Accuracy: 0.7067 | Loss: 0.9722\n", 282 | "Validation | Accuracy: 0.7637 | Loss: 0.8626\n", 283 | "\n", 284 | "Epoch 033 train | Accuracy: 0.7156 | Loss: 0.9586\n", 285 | "Validation | Accuracy: 0.7637 | Loss: 0.8758\n", 286 | "\n", 287 | "Epoch 034 train | Accuracy: 0.7400 | Loss: 0.8036\n", 288 | "Validation | Accuracy: 0.7818 | Loss: 0.6684\n", 289 | "\n", 290 | "Epoch 035 train | Accuracy: 0.7236 | Loss: 0.7616\n", 291 | "Validation | Accuracy: 0.7656 | Loss: 0.6470\n", 292 | "\n", 293 | "Epoch 036 train | Accuracy: 0.7291 | Loss: 0.7488\n", 294 | "Validation | Accuracy: 0.7575 | Loss: 0.6722\n", 295 | "\n", 296 | "Epoch 037 train | Accuracy: 0.7263 | Loss: 0.7521\n", 297 | "Validation | Accuracy: 0.7660 | Loss: 0.6747\n", 298 | "\n", 299 | "Epoch 038 train | Accuracy: 0.7295 | Loss: 0.7247\n", 300 | "Validation | Accuracy: 0.7785 | Loss: 0.6606\n", 301 | "\n", 302 | "Epoch 039 train | Accuracy: 0.7382 | Loss: 0.7299\n", 303 | "Validation | Accuracy: 0.7644 | Loss: 0.6917\n", 304 | "\n", 305 | "Epoch 040 train | Accuracy: 0.7396 | Loss: 0.7279\n", 306 | 
"Validation | Accuracy: 0.7660 | Loss: 0.6315\n", 307 | "\n", 308 | "Epoch 041 train | Accuracy: 0.7325 | Loss: 0.7315\n", 309 | "Validation | Accuracy: 0.7559 | Loss: 0.6761\n", 310 | "\n", 311 | "Epoch 042 train | Accuracy: 0.7321 | Loss: 0.7358\n", 312 | "Validation | Accuracy: 0.7396 | Loss: 0.7109\n", 313 | "\n", 314 | "Epoch 043 train | Accuracy: 0.7368 | Loss: 0.7290\n", 315 | "Validation | Accuracy: 0.7660 | Loss: 0.6499\n", 316 | "\n", 317 | "Epoch 044 train | Accuracy: 0.7380 | Loss: 0.7109\n", 318 | "Validation | Accuracy: 0.7532 | Loss: 0.6670\n", 319 | "\n", 320 | "Epoch 045 train | Accuracy: 0.7202 | Loss: 0.7453\n", 321 | "Validation | Accuracy: 0.7772 | Loss: 0.6693\n", 322 | "\n", 323 | "Epoch 046 train | Accuracy: 0.7254 | Loss: 0.7297\n", 324 | "Validation | Accuracy: 0.7695 | Loss: 0.6841\n", 325 | "\n", 326 | "Epoch 047 train | Accuracy: 0.7426 | Loss: 0.6843\n", 327 | "Validation | Accuracy: 0.7671 | Loss: 0.6631\n", 328 | "\n", 329 | "Epoch 048 train | Accuracy: 0.7558 | Loss: 0.6280\n", 330 | "Validation | Accuracy: 0.7940 | Loss: 0.5870\n", 331 | "\n", 332 | "Epoch 049 train | Accuracy: 0.7952 | Loss: 0.5270\n", 333 | "Validation | Accuracy: 0.8160 | Loss: 0.5373\n", 334 | "\n", 335 | "Epoch 050 train | Accuracy: 0.7888 | Loss: 0.5146\n", 336 | "Validation | Accuracy: 0.7991 | Loss: 0.5687\n", 337 | "\n", 338 | "Epoch 051 train | Accuracy: 0.7889 | Loss: 0.4979\n", 339 | "Validation | Accuracy: 0.8385 | Loss: 0.5402\n", 340 | "\n", 341 | "Epoch 052 train | Accuracy: 0.7949 | Loss: 0.4921\n", 342 | "Validation | Accuracy: 0.8165 | Loss: 0.5323\n", 343 | "\n", 344 | "Epoch 053 train | Accuracy: 0.8030 | Loss: 0.5061\n", 345 | "Validation | Accuracy: 0.8261 | Loss: 0.5269\n", 346 | "\n", 347 | "Epoch 054 train | Accuracy: 0.7964 | Loss: 0.4944\n", 348 | "Validation | Accuracy: 0.8080 | Loss: 0.5172\n", 349 | "\n", 350 | "Epoch 055 train | Accuracy: 0.8111 | Loss: 0.4711\n", 351 | "Validation | Accuracy: 0.7925 | Loss: 0.5837\n", 352 | "\n", 353 | "Epoch 056 train | Accuracy: 0.7933 | Loss: 0.4973\n", 354 | "Validation | Accuracy: 0.8129 | Loss: 0.5434\n", 355 | "\n", 356 | "Epoch 057 train | Accuracy: 0.7929 | Loss: 0.4881\n", 357 | "Validation | Accuracy: 0.8235 | Loss: 0.5137\n", 358 | "\n", 359 | "Epoch 058 train | Accuracy: 0.7911 | Loss: 0.5114\n", 360 | "Validation | Accuracy: 0.8156 | Loss: 0.5244\n", 361 | "\n", 362 | "Epoch 059 train | Accuracy: 0.7859 | Loss: 0.5567\n", 363 | "Validation | Accuracy: 0.7969 | Loss: 0.5607\n", 364 | "\n", 365 | "Epoch 060 train | Accuracy: 0.7789 | Loss: 0.5163\n", 366 | "Validation | Accuracy: 0.8122 | Loss: 0.5188\n", 367 | "\n", 368 | "Epoch 061 train | Accuracy: 0.8033 | Loss: 0.4811\n", 369 | "Validation | Accuracy: 0.8110 | Loss: 0.5580\n", 370 | "\n", 371 | "Epoch 062 train | Accuracy: 0.7868 | Loss: 0.5121\n", 372 | "Validation | Accuracy: 0.8303 | Loss: 0.5065\n", 373 | "\n", 374 | "Epoch 063 train | Accuracy: 0.8043 | Loss: 0.4958\n", 375 | "Validation | Accuracy: 0.8261 | Loss: 0.5020\n", 376 | "\n", 377 | "Epoch 064 train | Accuracy: 0.8136 | Loss: 0.4574\n", 378 | "Validation | Accuracy: 0.8207 | Loss: 0.5315\n", 379 | "\n", 380 | "Epoch 065 train | Accuracy: 0.8121 | Loss: 0.4612\n", 381 | "Validation | Accuracy: 0.8160 | Loss: 0.5572\n", 382 | "\n", 383 | "Epoch 066 train | Accuracy: 0.8076 | Loss: 0.4786\n", 384 | "Validation | Accuracy: 0.8515 | Loss: 0.4885\n", 385 | "\n", 386 | "Epoch 067 train | Accuracy: 0.8007 | Loss: 0.4980\n", 387 | "Validation | Accuracy: 0.8246 | Loss: 0.5288\n", 388 | "\n", 389 | 
"Epoch 068 train | Accuracy: 0.8081 | Loss: 0.4669\n", 390 | "Validation | Accuracy: 0.8292 | Loss: 0.5461\n", 391 | "\n", 392 | "Epoch 069 train | Accuracy: 0.7999 | Loss: 0.4852\n", 393 | "Validation | Accuracy: 0.8275 | Loss: 0.5215\n", 394 | "\n", 395 | "Epoch 070 train | Accuracy: 0.8099 | Loss: 0.4691\n", 396 | "Validation | Accuracy: 0.7779 | Loss: 0.5756\n", 397 | "\n", 398 | "Epoch 071 train | Accuracy: 0.8070 | Loss: 0.4655\n", 399 | "Validation | Accuracy: 0.8226 | Loss: 0.5599\n", 400 | "\n", 401 | "Epoch 072 train | Accuracy: 0.8094 | Loss: 0.4665\n", 402 | "Validation | Accuracy: 0.8308 | Loss: 0.5148\n", 403 | "\n", 404 | "Epoch 073 train | Accuracy: 0.8061 | Loss: 0.4894\n", 405 | "Validation | Accuracy: 0.7961 | Loss: 0.6434\n", 406 | "\n", 407 | "Epoch 074 train | Accuracy: 0.8058 | Loss: 0.4604\n", 408 | "Validation | Accuracy: 0.8341 | Loss: 0.5299\n", 409 | "\n", 410 | "Epoch 075 train | Accuracy: 0.8040 | Loss: 0.4874\n", 411 | "Validation | Accuracy: 0.7764 | Loss: 0.5851\n", 412 | "\n", 413 | "Epoch 076 train | Accuracy: 0.8002 | Loss: 0.5061\n", 414 | "Validation | Accuracy: 0.8515 | Loss: 0.5216\n", 415 | "\n", 416 | "Epoch 077 train | Accuracy: 0.8084 | Loss: 0.4674\n", 417 | "Validation | Accuracy: 0.8462 | Loss: 0.5192\n", 418 | "\n", 419 | "Epoch 078 train | Accuracy: 0.8112 | Loss: 0.4691\n", 420 | "Validation | Accuracy: 0.8449 | Loss: 0.5057\n", 421 | "\n", 422 | "Epoch 079 train | Accuracy: 0.7997 | Loss: 0.4954\n", 423 | "Validation | Accuracy: 0.8490 | Loss: 0.4774\n", 424 | "\n", 425 | "Epoch 080 train | Accuracy: 0.8097 | Loss: 0.4641\n", 426 | "Validation | Accuracy: 0.8383 | Loss: 0.4884\n", 427 | "\n", 428 | "Epoch 081 train | Accuracy: 0.8013 | Loss: 0.4886\n", 429 | "Validation | Accuracy: 0.8398 | Loss: 0.5008\n", 430 | "\n", 431 | "Epoch 082 train | Accuracy: 0.8162 | Loss: 0.4538\n", 432 | "Validation | Accuracy: 0.8374 | Loss: 0.4985\n", 433 | "\n", 434 | "Epoch 083 train | Accuracy: 0.8150 | Loss: 0.4662\n", 435 | "Validation | Accuracy: 0.8235 | Loss: 0.5109\n", 436 | "\n", 437 | "Epoch 084 train | Accuracy: 0.8168 | Loss: 0.4569\n", 438 | "Validation | Accuracy: 0.8152 | Loss: 0.5018\n", 439 | "\n", 440 | "Epoch 085 train | Accuracy: 0.8207 | Loss: 0.4506\n", 441 | "Validation | Accuracy: 0.8264 | Loss: 0.5095\n", 442 | "\n", 443 | "Epoch 086 train | Accuracy: 0.8152 | Loss: 0.4405\n", 444 | "Validation | Accuracy: 0.8321 | Loss: 0.5205\n", 445 | "\n", 446 | "Epoch 087 train | Accuracy: 0.8023 | Loss: 0.4770\n", 447 | "Validation | Accuracy: 0.8253 | Loss: 0.4998\n", 448 | "\n", 449 | "Epoch 088 train | Accuracy: 0.8056 | Loss: 0.4793\n", 450 | "Validation | Accuracy: 0.8325 | Loss: 0.5118\n", 451 | "\n", 452 | "Epoch 089 train | Accuracy: 0.8115 | Loss: 0.4669\n", 453 | "Validation | Accuracy: 0.8341 | Loss: 0.5404\n", 454 | "\n", 455 | "Epoch 090 train | Accuracy: 0.8112 | Loss: 0.4691\n", 456 | "Validation | Accuracy: 0.8217 | Loss: 0.4803\n", 457 | "\n", 458 | "Epoch 091 train | Accuracy: 0.7910 | Loss: 0.5084\n", 459 | "Validation | Accuracy: 0.8066 | Loss: 0.5410\n", 460 | "\n", 461 | "Epoch 092 train | Accuracy: 0.8039 | Loss: 0.4557\n", 462 | "Validation | Accuracy: 0.8233 | Loss: 0.4833\n", 463 | "\n", 464 | "Epoch 093 train | Accuracy: 0.8252 | Loss: 0.4347\n", 465 | "Validation | Accuracy: 0.8070 | Loss: 0.5120\n", 466 | "\n", 467 | "Epoch 094 train | Accuracy: 0.8258 | Loss: 0.4382\n", 468 | "Validation | Accuracy: 0.8191 | Loss: 0.4966\n", 469 | "\n", 470 | "Epoch 095 train | Accuracy: 0.8060 | Loss: 0.4583\n", 471 | 
"Validation | Accuracy: 0.8301 | Loss: 0.4990\n", 472 | "\n", 473 | "Epoch 096 train | Accuracy: 0.8153 | Loss: 0.4519\n", 474 | "Validation | Accuracy: 0.8019 | Loss: 0.5176\n", 475 | "\n", 476 | "Epoch 097 train | Accuracy: 0.8229 | Loss: 0.4509\n", 477 | "Validation | Accuracy: 0.8193 | Loss: 0.5418\n", 478 | "\n", 479 | "Epoch 098 train | Accuracy: 0.8149 | Loss: 0.4600\n", 480 | "Validation | Accuracy: 0.8099 | Loss: 0.4866\n", 481 | "\n", 482 | "Epoch 099 train | Accuracy: 0.8142 | Loss: 0.4493\n", 483 | "Validation | Accuracy: 0.8363 | Loss: 0.4676\n", 484 | "\n", 485 | "Epoch 100 train | Accuracy: 0.8199 | Loss: 0.4343\n", 486 | "Validation | Accuracy: 0.8352 | Loss: 0.4537\n", 487 | "\n", 488 | "Epoch 101 train | Accuracy: 0.8193 | Loss: 0.4554\n", 489 | "Validation | Accuracy: 0.8433 | Loss: 0.4547\n", 490 | "\n", 491 | "Epoch 102 train | Accuracy: 0.8118 | Loss: 0.4428\n", 492 | "Validation | Accuracy: 0.8209 | Loss: 0.4599\n", 493 | "\n", 494 | "Epoch 103 train | Accuracy: 0.8151 | Loss: 0.4580\n", 495 | "Validation | Accuracy: 0.8125 | Loss: 0.5437\n", 496 | "\n", 497 | "Epoch 104 train | Accuracy: 0.8183 | Loss: 0.4504\n", 498 | "Validation | Accuracy: 0.8286 | Loss: 0.5081\n", 499 | "\n", 500 | "Epoch 105 train | Accuracy: 0.8212 | Loss: 0.4355\n", 501 | "Validation | Accuracy: 0.8290 | Loss: 0.4858\n", 502 | "\n", 503 | "Epoch 106 train | Accuracy: 0.8092 | Loss: 0.4422\n", 504 | "Validation | Accuracy: 0.8254 | Loss: 0.5441\n", 505 | "\n", 506 | "Epoch 107 train | Accuracy: 0.8204 | Loss: 0.4480\n", 507 | "Validation | Accuracy: 0.8099 | Loss: 0.5122\n", 508 | "\n", 509 | "Epoch 108 train | Accuracy: 0.8099 | Loss: 0.4504\n", 510 | "Validation | Accuracy: 0.8028 | Loss: 0.4674\n", 511 | "\n", 512 | "Epoch 109 train | Accuracy: 0.8306 | Loss: 0.4257\n", 513 | "Validation | Accuracy: 0.8421 | Loss: 0.4792\n", 514 | "\n", 515 | "Epoch 110 train | Accuracy: 0.8188 | Loss: 0.4490\n", 516 | "Validation | Accuracy: 0.8506 | Loss: 0.4727\n", 517 | "\n", 518 | "Epoch 111 train | Accuracy: 0.8196 | Loss: 0.4293\n", 519 | "Validation | Accuracy: 0.8348 | Loss: 0.5008\n", 520 | "\n", 521 | "Epoch 112 train | Accuracy: 0.8192 | Loss: 0.4438\n", 522 | "Validation | Accuracy: 0.8438 | Loss: 0.5139\n", 523 | "\n", 524 | "Epoch 113 train | Accuracy: 0.8191 | Loss: 0.4467\n", 525 | "Validation | Accuracy: 0.8196 | Loss: 0.5168\n", 526 | "\n", 527 | "Epoch 114 train | Accuracy: 0.8142 | Loss: 0.4659\n", 528 | "Validation | Accuracy: 0.8099 | Loss: 0.5942\n", 529 | "\n", 530 | "Epoch 115 train | Accuracy: 0.8018 | Loss: 0.5197\n", 531 | "Validation | Accuracy: 0.8212 | Loss: 0.5627\n", 532 | "\n", 533 | "Epoch 116 train | Accuracy: 0.8124 | Loss: 0.4615\n", 534 | "Validation | Accuracy: 0.8240 | Loss: 0.5261\n", 535 | "\n", 536 | "Epoch 117 train | Accuracy: 0.8304 | Loss: 0.4278\n", 537 | "Validation | Accuracy: 0.8471 | Loss: 0.4819\n", 538 | "\n", 539 | "Epoch 118 train | Accuracy: 0.8132 | Loss: 0.4222\n", 540 | "Validation | Accuracy: 0.8654 | Loss: 0.4488\n", 541 | "\n", 542 | "Epoch 119 train | Accuracy: 0.8155 | Loss: 0.4442\n", 543 | "Validation | Accuracy: 0.8502 | Loss: 0.4819\n", 544 | "\n", 545 | "Epoch 120 train | Accuracy: 0.8265 | Loss: 0.4256\n", 546 | "Validation | Accuracy: 0.8398 | Loss: 0.5046\n", 547 | "\n", 548 | "Epoch 121 train | Accuracy: 0.8157 | Loss: 0.4365\n", 549 | "Validation | Accuracy: 0.8230 | Loss: 0.4523\n", 550 | "\n", 551 | "Epoch 122 train | Accuracy: 0.8277 | Loss: 0.4147\n", 552 | "Validation | Accuracy: 0.8178 | Loss: 0.5228\n", 553 | "\n", 554 | 
"Epoch 123 train | Accuracy: 0.8142 | Loss: 0.4558\n", 555 | "Validation | Accuracy: 0.8303 | Loss: 0.4890\n", 556 | "\n", 557 | "Epoch 124 train | Accuracy: 0.8269 | Loss: 0.4221\n", 558 | "Validation | Accuracy: 0.8321 | Loss: 0.4943\n", 559 | "\n", 560 | "Epoch 125 train | Accuracy: 0.8202 | Loss: 0.4452\n", 561 | "Validation | Accuracy: 0.8271 | Loss: 0.5191\n", 562 | "\n", 563 | "Epoch 126 train | Accuracy: 0.8301 | Loss: 0.4146\n", 564 | "Validation | Accuracy: 0.8418 | Loss: 0.4629\n", 565 | "\n", 566 | "Epoch 127 train | Accuracy: 0.8218 | Loss: 0.4317\n", 567 | "Validation | Accuracy: 0.8315 | Loss: 0.4728\n", 568 | "\n", 569 | "Epoch 128 train | Accuracy: 0.8304 | Loss: 0.4313\n", 570 | "Validation | Accuracy: 0.8492 | Loss: 0.4740\n", 571 | "\n", 572 | "Epoch 129 train | Accuracy: 0.8275 | Loss: 0.4461\n", 573 | "Validation | Accuracy: 0.8282 | Loss: 0.5227\n", 574 | "\n", 575 | "Epoch 130 train | Accuracy: 0.8245 | Loss: 0.4321\n", 576 | "Validation | Accuracy: 0.8204 | Loss: 0.4933\n", 577 | "\n", 578 | "Epoch 131 train | Accuracy: 0.8321 | Loss: 0.4165\n", 579 | "Validation | Accuracy: 0.8488 | Loss: 0.4675\n", 580 | "\n", 581 | "Epoch 132 train | Accuracy: 0.8287 | Loss: 0.4367\n", 582 | "Validation | Accuracy: 0.8464 | Loss: 0.4914\n", 583 | "\n", 584 | "Epoch 133 train | Accuracy: 0.8187 | Loss: 0.4377\n", 585 | "Validation | Accuracy: 0.8426 | Loss: 0.4860\n", 586 | "\n", 587 | "Epoch 134 train | Accuracy: 0.8244 | Loss: 0.4182\n", 588 | "Validation | Accuracy: 0.8442 | Loss: 0.5058\n", 589 | "\n", 590 | "Epoch 135 train | Accuracy: 0.8122 | Loss: 0.4466\n", 591 | "Validation | Accuracy: 0.8152 | Loss: 0.5541\n", 592 | "\n", 593 | "Epoch 136 train | Accuracy: 0.8088 | Loss: 0.4527\n", 594 | "Validation | Accuracy: 0.8456 | Loss: 0.4604\n", 595 | "\n", 596 | "Epoch 137 train | Accuracy: 0.8147 | Loss: 0.4472\n", 597 | "Validation | Accuracy: 0.8541 | Loss: 0.4809\n", 598 | "\n", 599 | "Epoch 138 train | Accuracy: 0.8319 | Loss: 0.4033\n", 600 | "Validation | Accuracy: 0.8445 | Loss: 0.5223\n", 601 | "\n", 602 | "Epoch 139 train | Accuracy: 0.8319 | Loss: 0.4307\n", 603 | "Validation | Accuracy: 0.8499 | Loss: 0.4868\n", 604 | "\n", 605 | "Epoch 140 train | Accuracy: 0.8324 | Loss: 0.4121\n", 606 | "Validation | Accuracy: 0.8000 | Loss: 0.5511\n", 607 | "\n", 608 | "Epoch 141 train | Accuracy: 0.8203 | Loss: 0.4265\n", 609 | "Validation | Accuracy: 0.8106 | Loss: 0.5179\n", 610 | "\n", 611 | "Epoch 142 train | Accuracy: 0.8287 | Loss: 0.4160\n", 612 | "Validation | Accuracy: 0.8530 | Loss: 0.4740\n", 613 | "\n", 614 | "Epoch 143 train | Accuracy: 0.8273 | Loss: 0.4274\n", 615 | "Validation | Accuracy: 0.8372 | Loss: 0.4843\n", 616 | "\n", 617 | "Epoch 144 train | Accuracy: 0.8260 | Loss: 0.4216\n", 618 | "Validation | Accuracy: 0.8480 | Loss: 0.4679\n", 619 | "\n", 620 | "Epoch 145 train | Accuracy: 0.8239 | Loss: 0.4139\n", 621 | "Validation | Accuracy: 0.7969 | Loss: 0.6099\n", 622 | "\n", 623 | "Epoch 146 train | Accuracy: 0.8160 | Loss: 0.4398\n", 624 | "Validation | Accuracy: 0.8350 | Loss: 0.5187\n", 625 | "\n", 626 | "Epoch 147 train | Accuracy: 0.8282 | Loss: 0.4429\n", 627 | "Validation | Accuracy: 0.8349 | Loss: 0.4710\n", 628 | "\n", 629 | "Epoch 148 train | Accuracy: 0.8366 | Loss: 0.4240\n", 630 | "Validation | Accuracy: 0.8407 | Loss: 0.5280\n", 631 | "\n", 632 | "Epoch 149 train | Accuracy: 0.8187 | Loss: 0.4306\n", 633 | "Validation | Accuracy: 0.8271 | Loss: 0.5176\n", 634 | "\n", 635 | "Epoch 150 train | Accuracy: 0.8245 | Loss: 0.4090\n", 636 | 
"Validation | Accuracy: 0.8310 | Loss: 0.4565\n", 637 | "\n", 638 | "Epoch 151 train | Accuracy: 0.8361 | Loss: 0.4129\n", 639 | "Validation | Accuracy: 0.8314 | Loss: 0.4877\n", 640 | "\n", 641 | "Epoch 152 train | Accuracy: 0.8428 | Loss: 0.3845\n", 642 | "Validation | Accuracy: 0.8282 | Loss: 0.6289\n", 643 | "\n", 644 | "Epoch 153 train | Accuracy: 0.8210 | Loss: 0.4130\n", 645 | "Validation | Accuracy: 0.8213 | Loss: 0.5477\n", 646 | "\n", 647 | "Epoch 154 train | Accuracy: 0.8302 | Loss: 0.4110\n", 648 | "Validation | Accuracy: 0.8179 | Loss: 0.5131\n", 649 | "\n", 650 | "Epoch 155 train | Accuracy: 0.8318 | Loss: 0.4209\n", 651 | "Validation | Accuracy: 0.8336 | Loss: 0.5208\n", 652 | "\n", 653 | "Epoch 156 train | Accuracy: 0.8280 | Loss: 0.4075\n", 654 | "Validation | Accuracy: 0.8282 | Loss: 0.5835\n", 655 | "\n", 656 | "Epoch 157 train | Accuracy: 0.8234 | Loss: 0.4249\n", 657 | "Validation | Accuracy: 0.8301 | Loss: 0.4964\n", 658 | "\n", 659 | "Epoch 158 train | Accuracy: 0.8221 | Loss: 0.4098\n", 660 | "Validation | Accuracy: 0.8171 | Loss: 0.5177\n", 661 | "\n", 662 | "Epoch 159 train | Accuracy: 0.8337 | Loss: 0.4050\n", 663 | "Validation | Accuracy: 0.8217 | Loss: 0.5299\n", 664 | "\n", 665 | "Epoch 160 train | Accuracy: 0.8320 | Loss: 0.3948\n", 666 | "Validation | Accuracy: 0.8034 | Loss: 0.5483\n", 667 | "\n", 668 | "Epoch 161 train | Accuracy: 0.8259 | Loss: 0.4206\n", 669 | "Validation | Accuracy: 0.8101 | Loss: 0.5308\n", 670 | "\n", 671 | "Epoch 162 train | Accuracy: 0.8266 | Loss: 0.4124\n", 672 | "Validation | Accuracy: 0.8119 | Loss: 0.5668\n", 673 | "\n", 674 | "Epoch 163 train | Accuracy: 0.8329 | Loss: 0.4265\n", 675 | "Validation | Accuracy: 0.8050 | Loss: 0.5457\n", 676 | "\n", 677 | "Epoch 164 train | Accuracy: 0.8222 | Loss: 0.4222\n", 678 | "Validation | Accuracy: 0.8147 | Loss: 0.5590\n", 679 | "\n", 680 | "Epoch 165 train | Accuracy: 0.8284 | Loss: 0.4057\n", 681 | "Validation | Accuracy: 0.8212 | Loss: 0.5220\n", 682 | "\n", 683 | "Epoch 166 train | Accuracy: 0.8336 | Loss: 0.4089\n", 684 | "Validation | Accuracy: 0.8123 | Loss: 0.5905\n", 685 | "\n", 686 | "Epoch 167 train | Accuracy: 0.8276 | Loss: 0.4125\n", 687 | "Validation | Accuracy: 0.8314 | Loss: 0.5342\n", 688 | "\n", 689 | "Epoch 168 train | Accuracy: 0.8386 | Loss: 0.4022\n", 690 | "Validation | Accuracy: 0.7993 | Loss: 0.5722\n", 691 | "\n", 692 | "Epoch 169 train | Accuracy: 0.8199 | Loss: 0.4368\n", 693 | "Validation | Accuracy: 0.7737 | Loss: 0.5683\n", 694 | "\n", 695 | "Epoch 170 train | Accuracy: 0.8254 | Loss: 0.4159\n", 696 | "Validation | Accuracy: 0.8010 | Loss: 0.5215\n", 697 | "\n", 698 | "Epoch 171 train | Accuracy: 0.8174 | Loss: 0.4304\n", 699 | "Validation | Accuracy: 0.8230 | Loss: 0.4990\n", 700 | "\n", 701 | "Epoch 172 train | Accuracy: 0.8334 | Loss: 0.3965\n", 702 | "Validation | Accuracy: 0.8096 | Loss: 0.5055\n", 703 | "\n", 704 | "Epoch 173 train | Accuracy: 0.8192 | Loss: 0.4260\n", 705 | "Validation | Accuracy: 0.7982 | Loss: 0.5523\n", 706 | "\n", 707 | "Epoch 174 train | Accuracy: 0.8394 | Loss: 0.3984\n", 708 | "Validation | Accuracy: 0.8221 | Loss: 0.5493\n", 709 | "\n", 710 | "Epoch 175 train | Accuracy: 0.8224 | Loss: 0.4162\n", 711 | "Validation | Accuracy: 0.8325 | Loss: 0.5477\n", 712 | "\n", 713 | "Epoch 176 train | Accuracy: 0.8464 | Loss: 0.3886\n", 714 | "Validation | Accuracy: 0.8198 | Loss: 0.4955\n", 715 | "\n", 716 | "Epoch 177 train | Accuracy: 0.8335 | Loss: 0.4129\n", 717 | "Validation | Accuracy: 0.8016 | Loss: 0.6229\n", 718 | "\n", 719 | 
"Epoch 178 train | Accuracy: 0.8266 | Loss: 0.4312\n", 720 | "Validation | Accuracy: 0.7843 | Loss: 0.5803\n", 721 | "\n", 722 | "Epoch 179 train | Accuracy: 0.8345 | Loss: 0.4151\n", 723 | "Validation | Accuracy: 0.8341 | Loss: 0.4561\n", 724 | "\n", 725 | "Epoch 180 train | Accuracy: 0.8258 | Loss: 0.4274\n", 726 | "Validation | Accuracy: 0.8115 | Loss: 0.6246\n", 727 | "\n", 728 | "Epoch 181 train | Accuracy: 0.8120 | Loss: 0.4451\n", 729 | "Validation | Accuracy: 0.8209 | Loss: 0.5661\n", 730 | "\n", 731 | "Epoch 182 train | Accuracy: 0.8430 | Loss: 0.3981\n", 732 | "Validation | Accuracy: 0.7984 | Loss: 0.5292\n", 733 | "\n", 734 | "Epoch 183 train | Accuracy: 0.8233 | Loss: 0.4093\n", 735 | "Validation | Accuracy: 0.7516 | Loss: 0.5966\n", 736 | "\n", 737 | "Epoch 184 train | Accuracy: 0.8280 | Loss: 0.4203\n", 738 | "Validation | Accuracy: 0.8055 | Loss: 0.5800\n", 739 | "\n", 740 | "Epoch 185 train | Accuracy: 0.8271 | Loss: 0.4399\n", 741 | "Validation | Accuracy: 0.7920 | Loss: 0.5809\n", 742 | "\n", 743 | "Epoch 186 train | Accuracy: 0.8366 | Loss: 0.4079\n", 744 | "Validation | Accuracy: 0.7867 | Loss: 0.5190\n", 745 | "\n", 746 | "Epoch 187 train | Accuracy: 0.8230 | Loss: 0.4540\n", 747 | "Validation | Accuracy: 0.8268 | Loss: 0.4903\n", 748 | "\n", 749 | "Epoch 188 train | Accuracy: 0.8364 | Loss: 0.4072\n", 750 | "Validation | Accuracy: 0.7905 | Loss: 0.5433\n", 751 | "\n", 752 | "Epoch 189 train | Accuracy: 0.8409 | Loss: 0.3807\n", 753 | "Validation | Accuracy: 0.8181 | Loss: 0.5228\n", 754 | "\n", 755 | "Epoch 190 train | Accuracy: 0.8407 | Loss: 0.4064\n", 756 | "Validation | Accuracy: 0.7826 | Loss: 0.5520\n", 757 | "\n", 758 | "Epoch 191 train | Accuracy: 0.8243 | Loss: 0.4237\n", 759 | "Validation | Accuracy: 0.8292 | Loss: 0.5378\n", 760 | "\n", 761 | "Epoch 192 train | Accuracy: 0.8455 | Loss: 0.3978\n", 762 | "Validation | Accuracy: 0.8290 | Loss: 0.5414\n", 763 | "\n", 764 | "Epoch 193 train | Accuracy: 0.8393 | Loss: 0.4174\n", 765 | "Validation | Accuracy: 0.8200 | Loss: 0.5017\n", 766 | "\n", 767 | "Epoch 194 train | Accuracy: 0.8283 | Loss: 0.4090\n", 768 | "Validation | Accuracy: 0.8306 | Loss: 0.4988\n", 769 | "\n", 770 | "Epoch 195 train | Accuracy: 0.8308 | Loss: 0.4160\n", 771 | "Validation | Accuracy: 0.8188 | Loss: 0.4809\n", 772 | "\n", 773 | "Epoch 196 train | Accuracy: 0.8334 | Loss: 0.4216\n", 774 | "Validation | Accuracy: 0.8325 | Loss: 0.5507\n", 775 | "\n", 776 | "Epoch 197 train | Accuracy: 0.8348 | Loss: 0.4000\n", 777 | "Validation | Accuracy: 0.7998 | Loss: 0.6347\n", 778 | "\n", 779 | "Epoch 198 train | Accuracy: 0.8315 | Loss: 0.4025\n", 780 | "Validation | Accuracy: 0.8514 | Loss: 0.5078\n", 781 | "\n", 782 | "Epoch 199 train | Accuracy: 0.8325 | Loss: 0.4093\n", 783 | "Validation | Accuracy: 0.8312 | Loss: 0.4781\n", 784 | "\n", 785 | "Epoch 200 train | Accuracy: 0.8484 | Loss: 0.3748\n", 786 | "Validation | Accuracy: 0.8362 | Loss: 0.4827\n", 787 | "\n", 788 | "Finish training! 
Using 391.5677 s:\n" 789 | ] 790 | } 791 | ], 792 | "source": [ 793 | "train_acc_all, train_loss_all = [], []\n", 794 | "val_acc_all, val_loss_all = [], []\n", 795 | "\n", 796 | "\n", 797 | "train_startime = time.time()\n", 798 | "\n", 799 | "for epoch in range(epochs):\n", 800 | "\n", 801 | " #### train one epoch \n", 802 | " model.train()\n", 803 | "\n", 804 | " train_acc_list = []\n", 805 | " train_loss_list = []\n", 806 | "\n", 807 | " # feed graph to algorithm one by one\n", 808 | " for batch, subgraph in enumerate(train_dataloader):\n", 809 | "\n", 810 | " subgraph = subgraph.to(device) \n", 811 | " nfeat = subgraph.ndata['feat'].float()\n", 812 | " efeat = subgraph.edata['relation'].float()\n", 813 | "\n", 814 | " logits = model(subgraph, nfeat, efeat) # get the prediction from models \n", 815 | "\n", 816 | " # calculate the accuracy \n", 817 | " gt = torch.argmax(subgraph.ndata['label'], dim=1) # ground true labels\n", 818 | " pre = torch.argmax(logits, dim=1) # prediction labels \n", 819 | " correct = torch.sum(pre == gt) # calculate the right labels \n", 820 | "\n", 821 | " acc = correct.item()*1.0/len(gt) # calculate the accuracy \n", 822 | " train_acc_list.append(acc) \n", 823 | "\n", 824 | " # compute the loss\n", 825 | " loss = F.cross_entropy(logits, gt) # using cross entropy \n", 826 | " train_loss_list.append(loss.item()) \n", 827 | "\n", 828 | " # backward propagation\n", 829 | " optimizer.zero_grad()\n", 830 | " loss.backward()\n", 831 | " optimizer.step()\n", 832 | "\n", 833 | " # calculate acc and loss for each epoch\n", 834 | " train_loss_epoch = np.array(train_loss_list).mean()\n", 835 | " train_acc_epoch = np.array(train_acc_list).mean()\n", 836 | "\n", 837 | " train_loss_all.append(train_loss_epoch)\n", 838 | " train_acc_all.append(train_acc_epoch)\n", 839 | "\n", 840 | " print(\"Epoch {:03d} train | Accuracy: {:.4f} | Loss: {:.4f}\".format(\\\n", 841 | " epoch+1, train_acc_epoch, train_loss_epoch))\n", 842 | "\n", 843 | "\n", 844 | " #### start evaluation\n", 845 | " val_acc_list, val_loss_list = [], []\n", 846 | "\n", 847 | " for batch, subgraph in enumerate(valid_dataloader):\n", 848 | " subgraph = subgraph.to(device)\n", 849 | "\n", 850 | " # calculate the accuracy and loss\n", 851 | " nfeat = subgraph.ndata['feat'].float()\n", 852 | " efeat = subgraph.edata['relation'].float()\n", 853 | "\n", 854 | " acc, loss, _, _, _, _ = evalEdge(model, nfeat, efeat, subgraph, subgraph.ndata['label'], n_classes)\n", 855 | "\n", 856 | " # obtain acc and loss\n", 857 | " val_acc_list.append(acc)\n", 858 | " val_loss_list.append(loss.item())\n", 859 | "\n", 860 | " # calculate the loss and acc for all graphs in one epoch\n", 861 | " val_loss_epoch = np.array(val_loss_list).mean()\n", 862 | " val_acc_epoch = np.array(val_acc_list).mean()\n", 863 | "\n", 864 | " # append for drawing the curs\n", 865 | " val_acc_all.append(val_acc_epoch)\n", 866 | " val_loss_all.append(val_loss_epoch)\n", 867 | " \n", 868 | " print(\"Validation | Accuracy: {:.4f} | Loss: {:.4f}\\n\".format(val_acc_epoch, val_loss_epoch))\n", 869 | "\n", 870 | " ############ save the best acc epoch ############\n", 871 | " if val_acc_epoch >= max(val_acc_all):\n", 872 | " torch.save(model, \"best_user.pt\")\n", 873 | " \n", 874 | "train_endtime = time.time()\n", 875 | "\n", 876 | "print(\"Finish training! 
Using {:.4f} s:\".format(train_endtime - train_startime))\n" 877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "metadata": {}, 882 | "source": [ 883 | "### Start to test \n", 884 | "\n", 885 | "We provide our best weight file here (named \"best_default.pt\").\n", 886 | "The test accuracy based on our weight file is around 80%. \n", 887 | "\n", 888 | "You can also set your own criterion to select the best weights, and test them here. " 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "execution_count": 7, 894 | "metadata": {}, 895 | "outputs": [ 896 | { 897 | "name": "stdout", 898 | "output_type": "stream", 899 | "text": [ 900 | "Load user best weight\n", 901 | "Test Accuracy: 0.7970\n", 902 | "F1 score: 0.780126\n", 903 | "Test time: 0.2603 s\n", 904 | "Confusion matrix:\n", 905 | "[[35 0 9 0 0 0 0 0 1]\n", 906 | " [ 1 44 0 0 0 0 0 0 0]\n", 907 | " [10 0 34 0 0 0 0 0 0]\n", 908 | " [ 0 0 0 72 6 3 0 0 0]\n", 909 | " [ 0 0 0 3 27 2 0 0 0]\n", 910 | " [ 1 0 0 17 0 46 3 0 3]\n", 911 | " [ 0 0 0 2 0 17 8 0 0]\n", 912 | " [ 0 0 0 0 0 0 0 41 0]\n", 913 | " [ 0 0 0 4 0 4 0 0 22]]\n" 914 | ] 915 | } 916 | ], 917 | "source": [ 918 | "if not os.path.exists(\"best_user.pt\"):\n", 919 | " filename = \"best_default.pt\"\n", 920 | " print(\"Load default best weight\")\n", 921 | "else:\n", 922 | " filename = \"best_user.pt\"\n", 923 | " print(\"Load user best weight\")\n", 924 | "\n", 925 | "\n", 926 | "model = torch.load(filename) # read the best weight\n", 927 | "model.eval() \n", 928 | "\n", 929 | "# print(model)\n", 930 | "\n", 931 | "test_acc_list = [] # list for storing the acc from each graph\n", 932 | "pre, gt = [], []\n", 933 | "\n", 934 | "\n", 935 | "test_startime = time.time()\n", 936 | "\n", 937 | "for batch, subgraph in enumerate(test_dataloader):\n", 938 | " subgraph = subgraph.to(device)\n", 939 | " # subgraph = dgl.add_self_loop(subgraph)\n", 940 | "\n", 941 | " nfeat = subgraph.ndata['feat'].float()\n", 942 | " efeat = subgraph.edata['relation'].float()\n", 943 | "\n", 944 | " acc, _, _, _, one_pre, one_gt = evalEdge(model, nfeat, efeat, \\\n", 945 | " subgraph, subgraph.ndata['label'], n_classes)\n", 946 | "\n", 947 | " test_acc_list.append(acc)\n", 948 | " pre.extend(one_pre)\n", 949 | " gt.extend(one_gt)\n", 950 | "\n", 951 | "\n", 952 | "\n", 953 | "test_time = time.time() - test_startime\n", 954 | "\n", 955 | "test_acc = np.array(test_acc_list).mean()\n", 956 | "\n", 957 | "cm = confusion_matrix(gt, pre) # confusion matrix, default function from scikit-learn\n", 958 | "\n", 959 | "f1 = f1_score(gt, pre, average='macro') # f1, default function from scikit-learn\n", 960 | "\n", 961 | "print(\"Test Accuracy: {:.4f}\".format(test_acc))\n", 962 | "\n", 963 | "print(\"F1 score: {:4f}\".format(f1))\n", 964 | "\n", 965 | "print(\"Test time: {:.4f} s\".format(test_time))\n", 966 | "\n", 967 | "print(f\"Confusion matrix:\\n{cm}\")" 968 | ] 969 | } 970 | ], 971 | "metadata": { 972 | "interpreter": { 973 | "hash": "7f7df54c7b13a7ddac48d89097bf9b5f3d407805e73bd908f10313d9de5254a8" 974 | }, 975 | "kernelspec": { 976 | "display_name": "Python 3.6.12 64-bit ('gnn': conda)", 977 | "name": "python3" 978 | }, 979 | "language_info": { 980 | "codemirror_mode": { 981 | "name": "ipython", 982 | "version": 3 983 | }, 984 | "file_extension": ".py", 985 | "mimetype": "text/x-python", 986 | "name": "python", 987 | "nbconvert_exporter": "python", 988 | "pygments_lexer": "ipython3", 989 | "version": "3.6.13" 990 | }, 991 | "orig_nbformat": 4 992 | }, 993 | "nbformat": 4, 994 | "nbformat_minor": 2 995 | } 
996 | --------------------------------------------------------------------------------